mmengine.registry
是一个在多个开源项目(如 MMEngine 和 MM系列工具包,例如 MMDetection、MMClassification 等)中使用的模块化管理机制,旨在通过注册表动态管理不同的组件,例如模型、算法、损失函数或插件等。这个机制极大地增强了项目的灵活性和扩展性。
1 基本概念
1.1 Registry 类
Registry是一个存储映射关系的容器,它将字符串类型的键(通常是组件的名字)映射到具体的 Python 对象(如类或函数)。使用注册表的主要优势是可以在运行时动态创建对象,从而实现高度可配置和可扩展的设计。
1.2 核心功能
注册机制:
提供了一个 register_module
方法,允许开发者将类或函数注册到注册表中。这通常通过装饰器的形式实现,使得代码更加简洁和直观。注册时可以指定一个或多个名称作为键,关联到相应的 Python 类或函数。
动态创建实例:
注册表可以使用 build
方法根据配置动态创建类的实例。这通常需要配置信息包括键名 type
(对应注册的名称)和其他用于初始化对象的参数。通过读取配置文件(通常是 JSON 或 YAML 格式),可以在不修改代码的情况下轻松切换使用的组件或调整参数。
作用域管理:
在某些复杂的应用场景中,Registry
可能需要处理多个域(scope)下的注册问题,例如不同的模块可能需要有各自独立的注册表。Registry
类可以通过指定 scope
参数来实现作用域管理,使得同一个名字在不同的作用域下可以关联到不同的对象。
继承:
注册表可以设置父注册表,使得查找过程可以在当前注册表未找到对应键值时回溯到父注册表中查找。这为模块间的依赖提供了方便,也允许更灵活的重载和扩展。
2 实现示例
2.1 简化的Registry类实现
展示了如何定义这个类,以及如何注册和创建对象:
class Registry: def __init__(self, name, scope=None, parent=None): self.name = name self.scope = scope self.parent = parent self._module_dict = {} def register_module(self, name=None): def _register(cls): module_name = name or cls.__name__ if module_name in self._module_dict: raise KeyError(f'{module_name} is already registered in {self.scope}::{self.name}') self._module_dict[module_name] = cls return cls return _register def get(self, name): if name in self._module_dict: return self._module_dict[name] elif self.parent: return self.parent.get(name) else: raise KeyError(f'{name} is not registered in {self.scope}::{self.name} and no parent registry to fallback.') def build(self, cfg): module_name = cfg['type'] module_cls = self.get(module_name) return module_cls(**{k: v for k, v in cfg.items() if k != 'type'})
2.2 这个注册表类
from torch import nn # 创建父注册表 MODELS = Registry('models', scope='mmengine') # 注册一个模块到父注册表 @MODELS.register_module() class ResNet(nn.Module): def __init__(self, layers): self.layers = layers def forward(self, x): return x # 创建子注册表,指定 MODELS 为父注册表 DETECTORS = Registry('detectors', scope='mmengine', parent=MODELS) # 注册一个模块只到子注册表 @DETECTORS.register_module() class FasterRCNN(nn.Module): def __init__(self, num_classes): self.num_classes = num_classes def forward(self, x): return x # 从子注册表构建 ResNet 实例,尽管它是在父注册表中注册的 resnet_instance = DETECTORS.build({'type': 'ResNet', 'layers': 50}) print(resnet_instance) # 直接从子注册表构建 FasterRCNN 实例 fasterrcnn_instance = DETECTORS.build({'type': 'FasterRCNN', 'num_classes': 80}) print(fasterrcnn_instance)
到此这篇关于python如何通过注册表动态管理组件的文章就介绍到这了,更多相关python动态管理组件内容请搜索IT俱乐部以前的文章或继续浏览下面的相关文章希望大家以后多多支持IT俱乐部!