# Source code for mindformers.tools.register.register
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Class Register Module For MindFormers."""
import inspect
import os
from mindformers.tools.hub.dynamic_module_utils import get_class_from_dynamic_module
from mindformers.tools.utils import get_context
# Prefix that MindFormerRegister adds to MODELS / CONFIG class names when the legacy flag is off.
NEW_CLASS_PREFIX = "mcore_"


def get_legacy():
    """Report whether the MindFormers ``use_legacy`` context flag is in effect.

    An unset (``None``) flag counts as legacy mode; any other value is converted with ``bool``.
    If reading the context raises ``RuntimeError``, legacy mode is assumed.

    Returns:
        bool: ``True`` in legacy mode, ``False`` otherwise.
    """
    try:
        flag = get_context("use_legacy")
        return True if flag is None else bool(flag)
    except RuntimeError:
        return True
class MindFormerModuleType:
    """
    Enumerated class of the MindFormers module type, which includes:

    .. list-table::
        :widths: 25 25
        :header-rows: 1

        * - Enumeration Name
          - Value
        * - BASE_LAYER
          - 'base_layer'
        * - CALLBACK
          - 'callback'
        * - CONFIG
          - 'config'
        * - CONTEXT
          - 'context'
        * - CORE
          - 'core'
        * - DATASET
          - 'dataset'
        * - DATASET_LOADER
          - 'dataset_loader'
        * - DATASET_SAMPLER
          - 'dataset_sampler'
        * - DATA_HANDLER
          - 'data_handler'
        * - ENCODER
          - 'encoder'
        * - FEATURE_EXTRACTOR
          - 'feature_extractor'
        * - HEAD
          - 'head'
        * - LOSS
          - 'loss'
        * - LR
          - 'lr'
        * - MASK_POLICY
          - 'mask_policy'
        * - METRIC
          - 'metric'
        * - MODELS
          - 'models'
        * - MODULES
          - 'modules'
        * - OPTIMIZER
          - 'optimizer'
        * - PIPELINE
          - 'pipeline'
        * - PROCESSOR
          - 'processor'
        * - TOKENIZER
          - 'tokenizer'
        * - TOOLS
          - 'tools'
        * - TRAINER
          - 'trainer'
        * - TRANSFORMER
          - 'transformer'
        * - TRANSFORMS
          - 'transforms'
        * - WRAPPER
          - 'wrapper'

    Examples:
        >>> from mindformers.tools import MindFormerModuleType
        >>>
        >>> print(MindFormerModuleType.MODULES)
        modules
    """

    def __init__(self):
        pass

    TRAINER = 'trainer'
    PIPELINE = 'pipeline'
    PROCESSOR = 'processor'
    TOKENIZER = 'tokenizer'
    DATASET = 'dataset'
    MASK_POLICY = 'mask_policy'
    DATASET_LOADER = 'dataset_loader'
    DATASET_SAMPLER = 'dataset_sampler'
    TRANSFORMS = 'transforms'
    ENCODER = 'encoder'
    MODELS = 'models'
    MODULES = 'modules'
    TRANSFORMER = 'transformer'
    BASE_LAYER = 'base_layer'
    CORE = 'core'
    HEAD = 'head'
    LOSS = 'loss'
    LR = 'lr'
    OPTIMIZER = 'optimizer'
    CONTEXT = 'context'
    CALLBACK = 'callback'
    WRAPPER = 'wrapper'
    METRIC = 'metric'
    CONFIG = 'config'
    TOOLS = 'tools'
    FEATURE_EXTRACTOR = 'feature_extractor'
    DATA_HANDLER = 'data_handler'
class MindFormerRegister:
    """
    The registration interface for MindFormers, provides methods for registering and obtaining the interface.

    The examples below assume legacy mode (``use_legacy`` unset or true); outside legacy mode, MODELS and CONFIG
    names are looked up with the ``mcore_`` prefix.

    Examples:
        >>> from mindformers.tools import MindFormerModuleType, MindFormerRegister
        >>>
        >>>
        >>> # Using decorator to register the class
        >>> @MindFormerRegister.register(MindFormerModuleType.CONFIG)
        ... class MyConfig:
        ...     def __init__(self, param):
        ...         self.param = param
        >>>
        >>>
        >>> # Using method to register the class
        >>> _ = MindFormerRegister.register_cls(register_class=MyConfig, module_type=MindFormerModuleType.CONFIG)
        >>>
        >>> print(MindFormerRegister.is_exist(module_type=MindFormerModuleType.CONFIG, class_name="MyConfig"))
        True
        >>> cls = MindFormerRegister.get_cls(module_type=MindFormerModuleType.CONFIG, class_name="MyConfig")
        >>> print(cls.__name__)
        MyConfig
        >>> instance = MindFormerRegister.get_instance_from_cfg(cfg={'type': 'MyConfig', 'param': 0},
        ...                                                     module_type=MindFormerModuleType.CONFIG)
        >>> print(instance.__class__.__name__)
        MyConfig
        >>> print(instance.param)
        0
        >>> instance = MindFormerRegister.get_instance(module_type=MindFormerModuleType.CONFIG,
        ...                                            class_name="MyConfig",
        ...                                            param=0)
        >>> print(instance.__class__.__name__)
        MyConfig
        >>> print(instance.param)
        0
    """

    def __init__(self):
        pass

    # module_type -> {registered class name -> registered class}
    registry = {}
    # (module_type, searchable name) -> registered class name, i.e. a key of registry[module_type]
    search_names_map = {}

    @classmethod
    def register(cls, module_type=MindFormerModuleType.TOOLS, alias=None, legacy=True, search_names=None):
        """
        A decorator that registers the class in the registry.

        Args:
            module_type (MindFormerModuleType, optional): Module type name of MindFormers.
                Default: ``MindFormerModuleType.TOOLS``.
            alias (str, optional): Alias for the class. Default: ``None``.
            legacy (bool, optional): Legacy Class or not. Default: ``True``.
            search_names (Union[str, tuple, list, set], optional): mapping search_names to a class_name.
                Default: ``None``.

        Returns:
            Wrapper, decorates the registered class.
        """

        def wrapper(register_class):
            """Register ``register_class`` via :meth:`register_cls` and return it unchanged.

            Args:
                register_class: The class to register.

            Returns:
                The same class, so the decorator leaves the class definition untouched.
            """
            return cls.register_cls(register_class, module_type=module_type, alias=alias,
                                    legacy=legacy, search_names=search_names)

        return wrapper

    @classmethod
    def register_cls(
            cls, register_class, module_type=MindFormerModuleType.TOOLS,
            alias=None, legacy=True, search_names=None
    ):
        """
        A method that registers a class into the registry.

        Args:
            register_class (type): The class that need to be registered.
            module_type (MindFormerModuleType, optional): Module type name of MindFormers.
                Default: ``MindFormerModuleType.TOOLS``.
            alias (str, optional): Alias for the class. Default: ``None``.
            legacy (bool, optional): Legacy Class or not. Default: ``True``.
            search_names (Union[str, tuple, list, set], optional): mapping search_names to a class_name.
                Default: ``None``.

        Returns:
            Class, the registered class itself.
        """
        class_name = alias if alias is not None else register_class.__name__
        class_name = cls._add_class_name_prefix(module_type, class_name, legacy)
        cls.registry.setdefault(module_type, {})[class_name] = register_class

        names = set()
        if isinstance(search_names, str):
            names.add(search_names)
        elif isinstance(search_names, (list, tuple, set)):
            names.update(search_names)
        # The registered (already prefixed) name is searchable too. It is prefixed once more below, so in
        # non-legacy mode a lookup that passes the prefixed name still resolves to this class.
        names.add(class_name)
        for search_name in names:
            key = cls._add_class_name_prefix(module_type, search_name, legacy)
            cls.search_names_map[(module_type, key)] = class_name
        return register_class

    @classmethod
    def is_exist(cls, module_type, class_name=None):
        """
        Determines whether the given class name is in the current type group. If `class_name` is not given,
        determines whether `module_type` itself is present in the registry.

        Args:
            module_type (MindFormerModuleType): Module type name of MindFormers.
            class_name (str, optional): Class name. Default: ``None``.

        Returns:
            A boolean value, indicating whether it exists or not.
        """
        if not class_name:
            return module_type in cls.registry
        resolved_name = cls._resolve_class_name(module_type, class_name)
        return module_type in cls.registry and resolved_name in cls.registry.get(module_type)

    @classmethod
    def get_cls(cls, module_type, class_name=None):
        """
        Get the class from the registry.

        Args:
            module_type (MindFormerModuleType): Module type name of MindFormers.
            class_name (str, optional): Class name. Default: ``None``.

        Returns:
            A registered class.

        Raises:
            ValueError: Can't find class `class_name` of type `module_type` in the registry.
            ValueError: Can't find type `module_type` in the registry.
            ValueError: `class_name` is empty or ``None``.
        """
        if not cls.is_exist(module_type, class_name):
            raise ValueError(f"Can't find class type {module_type} class name {class_name} in class registry "
                             f"when use_legacy={get_legacy()}")
        if not class_name:
            # `is_exist` only checks the module type when no class name is given.
            raise ValueError(f"Can't find class. class type = {module_type}, class name = {class_name}")
        # `is_exist` has just confirmed that the resolved name is registered under `module_type`.
        return cls.registry[module_type][cls._resolve_class_name(module_type, class_name)]

    @classmethod
    def get_instance_type_from_cfg(cls, cfg, module_type=MindFormerModuleType.MODELS):
        """
        Get the type of the object to build from a (non-legacy) configuration.

        The key that identifies the type depends on `module_type` and is popped from `cfg`:

        - ``CONFIG``: ``cfg['model_type']`` is looked up with :meth:`get_cls` and the registered class is returned.
        - ``MODELS``: ``cfg['architectures']`` is returned when it is a string, or its first element when it is
          a list.
        - any other type: ``cfg['type']`` is returned as-is.

        Args:
            cfg (dict): Configuration dictionary. The key used to find the type is removed from it.
            module_type (MindFormerModuleType, optional): Module type name of MindFormers.
                Default: ``MindFormerModuleType.MODELS``.

        Returns:
            A registered class (``CONFIG``), otherwise the value stored under the type key
            (a class name string or a class).

        Raises:
            KeyError: `cfg` lacks the key required for `module_type`.
            ValueError: ``architectures`` is neither a string nor a non-empty list.
            ValueError: Can't find the ``model_type`` class in the registry (``CONFIG`` only).
        """
        if module_type == MindFormerModuleType.CONFIG:
            return cls.get_cls(module_type, cfg.pop('model_type'))
        if module_type == MindFormerModuleType.MODELS:
            architectures = cfg.pop('architectures')
            if isinstance(architectures, str):
                return architectures
            if isinstance(architectures, list) and architectures:
                return architectures[0]
            raise ValueError("The type of model_config.architectures should be str or list of str.")
        return cfg.pop('type')

    @classmethod
    def get_instance_from_cfg(cls, cfg, module_type=MindFormerModuleType.TOOLS, default_args=None):
        """
        Get instances of the class in the registry via configuration.

        Args:
            cfg (dict): Configuration dictionary. It should contain at least the key "type".
            module_type (MindFormerModuleType, optional): Module type name of MindFormers.
                Default: ``MindFormerModuleType.TOOLS``.
            default_args (dict, optional): Extra initialization arguments. Entries here override
                entries of the same name in `cfg`. Default: ``None``.

        Returns:
            An instance of the class.

        Raises:
            TypeError: `cfg` must be a configuration.
            KeyError: `cfg` or `default_args` must contain the key "type".
            TypeError: `default_args` must be a dictionary or ``None``.
            ValueError: Can't find class `class_name` of type `module_type` in the registry.
        """
        if not isinstance(cfg, dict):
            raise TypeError(
                f"Cfg must be a Config, but got {type(cfg)}"
            )
        if 'auto_register' in cfg:
            # This pops the key from the caller's `cfg`; everything below works on a copy.
            cls.auto_register(class_reference=cfg.pop('auto_register'), module_type=module_type)
        # Use the same legacy switch as `get_cls`, so the config layout and the registry name prefix agree.
        use_legacy = get_legacy()
        if use_legacy or module_type not in [MindFormerModuleType.CONFIG, MindFormerModuleType.MODELS]:
            if 'type' not in cfg:
                raise KeyError(
                    '`cfg` or `default_args` must contain the key "type", '
                    'but got {}\n{}'.format(cfg, default_args)
                )
        if not (isinstance(default_args, dict) or default_args is None):
            raise TypeError(f'default_args must be a dict or None, but got {type(default_args)}')

        args = cfg.copy()
        if default_args is not None:
            # Despite the name, `default_args` entries take precedence over those in `cfg`.
            args.update(default_args)

        if use_legacy:
            obj_type = args.pop('type')
        else:
            obj_type = cls.get_instance_type_from_cfg(args, module_type)

        if isinstance(obj_type, str):
            obj_cls = cls.get_cls(module_type, obj_type)
        elif inspect.isclass(obj_type):
            obj_cls = obj_type
        else:
            raise ValueError(f"Can't find class type {module_type} class name {obj_type} in class registry")

        try:
            if not use_legacy and module_type == MindFormerModuleType.MODELS:
                return obj_cls(default_args['config'])
            return obj_cls(**args)
        except Exception as e:  # pylint: disable=broad-except
            wrapped = cls._prefix_exception(obj_cls, e)
            if wrapped is None:
                raise
            raise wrapped from e

    @classmethod
    def get_instance(cls, module_type=MindFormerModuleType.TOOLS, class_name=None, **kwargs):
        """
        Gets an instance of the class in the registry.

        Args:
            module_type (MindFormerModuleType, optional): Module type name of MindFormers.
                Default: ``MindFormerModuleType.TOOLS``.
            class_name (str, optional): Class name. A class object is also accepted and used directly.
                Default: ``None``.
            kwargs (Any): Additional keyword arguments for constructing instances of the class.

        Returns:
            An instance of the class.

        Raises:
            ValueError: `class_name` cannot be ``None``.
            ValueError: Can't find class `class_name` of type `module_type` in the registry.
        """
        if class_name is None:
            raise ValueError("Class name cannot be None.")
        if isinstance(class_name, str):
            obj_cls = cls.get_cls(module_type, class_name)
        elif inspect.isclass(class_name):
            obj_cls = class_name
        else:
            raise ValueError(f"Can't find class type {module_type} class name {class_name} in class registry.")

        try:
            return obj_cls(**kwargs)
        except Exception as e:  # pylint: disable=broad-except
            wrapped = cls._prefix_exception(obj_cls, e)
            if wrapped is None:
                raise
            raise wrapped from e

    @classmethod
    def auto_register(cls, class_reference: str, module_type=MindFormerModuleType.TOOLS):
        """
        Load a class from the directory given by the ``REGISTER_PATH`` environment variable and register it.

        Args:
            class_reference (str): The full name of the class to load, e.g. ``llama_model.LlamaForCausalLM``.
            module_type (MindFormerModuleType, optional): Module type to register the class under.
                Default: ``MindFormerModuleType.TOOLS``.

        Raises:
            ValueError: `class_reference` is not a string.
            EnvironmentError: ``REGISTER_PATH`` is unset, empty, or not an existing directory.
        """
        if not isinstance(class_reference, str):
            raise ValueError(f"auto_map must be the type of string, but get {type(class_reference)} . "
                             f"Please fill in the following format: module_file.function_name, such as, "
                             f"llama_model.LlamaForCausalLM")
        raw_register_path = os.getenv("REGISTER_PATH", '')
        if not raw_register_path:
            raise EnvironmentError("When configuring the 'auto_map' automatic registration function, "
                                   "REGISTER_PATH must be specified. "
                                   "It is recommended to complete this action "
                                   "through the official startup script "
                                   "'run_mindformer.py --register_path=module_file_path' "
                                   "or use 'export REGISTER_PATH=module_file_path' to complete this action.")
        register_path = os.path.realpath(raw_register_path)
        # `os.path.realpath` never returns an empty string, so validate that the directory actually exists.
        if not os.path.isdir(register_path):
            raise EnvironmentError(f"REGISTER_PATH must be real path, but get {raw_register_path}, "
                                   f"please specify the correct directory path.")
        module_class = get_class_from_dynamic_module(
            class_reference=class_reference, pretrained_model_name_or_path=register_path)
        cls.register_cls(module_class, module_type=module_type, legacy=get_legacy())

    @classmethod
    def _add_class_name_prefix(cls, module_type, class_name, legacy=True):
        """Prepend ``NEW_CLASS_PREFIX`` to MODELS / CONFIG class names when `legacy` is false."""
        if not legacy and module_type in [MindFormerModuleType.MODELS, MindFormerModuleType.CONFIG]:
            class_name = NEW_CLASS_PREFIX + class_name
        return class_name

    @classmethod
    def _resolve_class_name(cls, module_type, class_name):
        """Translate a user-facing class name into the key it is stored under in ``registry``.

        The name is prefixed according to the current ``use_legacy`` context, then mapped through
        ``search_names_map`` when a mapping exists.
        """
        class_name = cls._add_class_name_prefix(module_type, class_name, get_legacy())
        return cls.search_names_map.get((module_type, class_name), class_name)

    @staticmethod
    def _prefix_exception(obj_cls, error):
        """Build an exception of ``error``'s type whose message is prefixed with ``obj_cls``'s name.

        Returns ``None`` when that exception type cannot be built from a single message argument
        (e.g. ``UnicodeDecodeError``), so the caller can re-raise the original error unchanged.
        """
        try:
            return type(error)(f'{obj_cls.__name__}: {error}')
        except Exception:  # pylint: disable=broad-except
            return None