mindformers.tools.register.register 源代码

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file was refer to project
# https://gitee.com/mindspore/vision/blob/master/mindvision/engine/class_factory.py
# ============================================================================
""" Class Register Module For MindFormers."""

import inspect
import os

from mindformers.tools.hub.dynamic_module_utils import get_class_from_dynamic_module
from mindformers.version_control import check_tft_valid


[文档]class MindFormerModuleType: """ Enumerated class of the MindFormers module type, which includes: .. list-table:: :widths: 25 25 :header-rows: 1 * - Enumeration Name - Value * - CALLBACK - 'CALLBACK' * - CONFIG - 'config' * - CONTEXT - 'context' * - DATASET - 'dataset' * - DATASET_LOADER - 'dataset_loader' * - DATASET_SAMPLER - 'dataset_sampler' * - DATA_HANDLER - 'data_handler' * - ENCODER - 'encoder' * - FEATURE_EXTRACTOR - 'feature_extractor' * - LOSS - 'loss' * - LR - 'lr' * - MASK_POLICY - 'mask_policy' * - METRIC - 'metric' * - MODELS - 'models' * - OPTIMIZER - 'optimizer' * - PIPELINE - 'pipeline' * - PROCESSOR - 'processor' * - TOKENIZER - 'tokenizer' * - TOOLS - 'tools' * - TRAINER - 'trainer' * - TRANSFORMS - 'transforms' * - WRAPPER - 'wrapper' Examples: >>> from mindformers.tools import MindFormerModuleType >>> >>> print(MindFormerModuleType.MODULES) modules """ def __init__(self): pass TRAINER = 'trainer' PIPELINE = 'pipeline' PROCESSOR = 'processor' TOKENIZER = 'tokenizer' DATASET = 'dataset' MASK_POLICY = 'mask_policy' DATASET_LOADER = 'dataset_loader' DATASET_SAMPLER = 'dataset_sampler' TRANSFORMS = 'transforms' ENCODER = 'encoder' MODELS = 'models' MODULES = 'modules' TRANSFORMER = 'transformer' BASE_LAYER = 'base_layer' CORE = 'core' HEAD = 'head' LOSS = 'loss' LR = 'lr' OPTIMIZER = 'optimizer' CONTEXT = 'context' CALLBACK = 'callback' WRAPPER = 'wrapper' METRIC = 'metric' CONFIG = 'config' TOOLS = 'tools' FEATURE_EXTRACTOR = 'feature_extractor' DATA_HANDLER = 'data_handler'
[文档]class MindFormerRegister: """ The registration interface for MindFormers, provides methods for registering and obtaining the interface. Examples: >>> from mindformers.tools import MindFormerModuleType, MindFormerRegister >>> >>> >>> # Using decorator to register the class >>> @MindFormerRegister.register(MindFormerModuleType.CONFIG) >>> class MyConfig: ... def __init__(self, param): ... self.param = param >>> >>> >>> # Using method to register the class >>> MindFormerRegister.register_cls(register_class=MyConfig, module_type=MindFormerRegister) >>> >>> print(MindFormerRegister.is_exist(module_type=MindFormerModuleType.CONFIG, class_name="MyConfig")) True >>> cls = MindFormerRegister.get_cls(module_type=MindFormerModuleType.CONFIG, class_name="MyConfig") >>> print(cls.__name__) MyConfig >>> instance = MindFormerRegister.get_instance_from_cfg(cfg={'type': 'MyConfig', 'param': 0}, ... module_type=MindFormerModuleType.CONFIG) >>> print(instance.__class__.__name__) MyConfig >>> print(instance.param) 0 >>> instance = MindFormerRegister.get_instance(module_type=MindFormerModuleType.CONFIG, ... class_name="MyConfig", ... param=0) >>> print(instance.__class__.__name__) MyConfig >>> print(instance.param) 0 """ def __init__(self): pass registry = {}
[文档] @classmethod def register(cls, module_type=MindFormerModuleType.TOOLS, alias=None): """ A decorator that registers the class in the registry. Args: module_type (MindFormerModuleType, optional): Module type name of MindFormers. Default: ``MindFormerModuleType.TOOLS``. alias (str, optional): Alias for the class. Default: ``None``. Returns: Wrapper, decorates the registered class. """ def wrapper(register_class): """Register-Class with wrapper function. Args: register_class : class need to register Returns: wrapper of register_class """ class_name = alias if alias is not None else register_class.__name__ if module_type not in cls.registry: cls.registry[module_type] = {class_name: register_class} else: cls.registry[module_type][class_name] = register_class return register_class return wrapper
[文档] @classmethod def register_cls(cls, register_class, module_type=MindFormerModuleType.TOOLS, alias=None): """ A method that registers a class into the registry. Args: register_class (type): The class that need to be registered. module_type (MindFormerModuleType, optional): Module type name of MindFormers. Default: ``MindFormerModuleType.TOOLS``. alias (str, optional): Alias for the class. Default: ``None``. Returns: Class, the registered class itself. """ class_name = alias if alias is not None else register_class.__name__ if module_type not in cls.registry: cls.registry[module_type] = {class_name: register_class} else: cls.registry[module_type][class_name] = register_class return register_class
[文档] @classmethod def is_exist(cls, module_type, class_name=None): """ Determines whether the given class name is in the current type group. If `class_name` is not given, determines if the given class name is in the current registered dictionary. Args: module_type (MindFormerModuleType): Module type name of MindFormers. class_name (str, optional): Class name. Default: ``None``. Returns: A boolean value, indicating whether it exists or not. """ if not class_name: return module_type in cls.registry registered = module_type in cls.registry and class_name in cls.registry.get(module_type) return registered
[文档] @classmethod def get_cls(cls, module_type, class_name=None): """ Get the class from the registry. Args: module_type (MindFormerModuleType): Module type name of MindFormers. class_name (str, optional): Class name. Default: ``None``. Returns: A registered class. Raises: ValueError: Can't find class `class_name` of type `module_type` in the registry. ValueError: Can't find type `module_type` in the registry. """ if not cls.is_exist(module_type, class_name): raise ValueError("Can't find class type {} class name {} \ in class registry".format(module_type, class_name)) if not class_name: raise ValueError( "Can't find class. class type = {}".format(class_name)) register_class = cls.registry.get(module_type).get(class_name) return register_class
[文档] @classmethod def get_instance_from_cfg(cls, cfg, module_type=MindFormerModuleType.TOOLS, default_args=None): """ Get instances of the class in the registry via configuration. Args: cfg (dict): Configuration dictionary. It should contain at least the key "type". module_type (MindFormerModuleType, optional): Module type name of MindFormers. Default: ``MindFormerModuleType.TOOLS``. default_args (dict, optional): Default initialization arguments. Default: ``None``. Returns: An instance of the class. Raises: TypeError: `cfg` must be a configuration. KeyError: `cfg` or `default_args` must contain the key "type". TypeError: `default_args` must be a dictionary or ``None``. ValueError: Can't find class `class_name` of type `module_type` in the registry. """ if not isinstance(cfg, dict): raise TypeError( "Cfg must be a Config, but got {}".format(type(cfg)) ) if 'auto_register' in cfg: cls.auto_register(class_reference=cfg.pop('auto_register'), module_type=module_type) if 'type' not in cfg: raise KeyError( '`cfg` or `default_args` must contain the key "type",' 'but got {}\n{}'.format(cfg, default_args) ) if not (isinstance(default_args, dict) or default_args is None): raise TypeError( 'default_args must be a dict or None' 'but got {}'.format(type(default_args))) args = cfg.copy() if default_args is not None: for k, v in default_args.items(): if k not in args: args.setdefault(k, v) else: args[k] = v obj_type = args.pop('type') if isinstance(obj_type, str): obj_cls = cls.get_cls(module_type, obj_type) elif inspect.isclass(obj_type): obj_cls = obj_type else: raise ValueError("Can't find class type {} class name {} \ in class registry".format(type, obj_type)) try: if module_type == MindFormerModuleType.OPTIMIZER: if check_tft_valid(): from mindspore.train.callback import TrainFaultTolerance obj_cls = TrainFaultTolerance.get_optimizer_wrapper(obj_cls) return obj_cls(**args) except Exception as e: raise type(e)('{}: {}'.format(obj_cls.__name__, e))
[文档] @classmethod def get_instance(cls, module_type=MindFormerModuleType.TOOLS, class_name=None, **kwargs): """ Gets an instance of the class in the registry. Args: module_type (MindFormerModuleType, optional): Module type name of MindFormers. Default: ``MindFormerModuleType.TOOLS``. class_name (str, optional): Class name. Default: ``None``. kwargs (Any): Additional keyword arguments for constructing instances of the class. Returns: An instance of the class. Raises: ValueError: `class_name` cannot be ``None``. ValueError: Can't find class `class_name` of type `module_type` in the registry. """ if class_name is None: raise ValueError("Class name cannot be None.") if isinstance(class_name, str): obj_cls = cls.get_cls(module_type, class_name) elif inspect.isclass(class_name): obj_cls = class_name else: raise ValueError("Can't find class type {} class name {} \ in class registry.".format(type, class_name)) try: return obj_cls(**kwargs) except Exception as e: raise type(e)('{}: {}'.format(obj_cls.__name__, e))
@classmethod def auto_register(cls, class_reference: str, module_type=MindFormerModuleType.TOOLS): """ Auto register function. Args: class_reference (str): The full name of the class to load. module_type (MindFormerModuleType.TOOLS): module type. """ if not isinstance(class_reference, str): raise ValueError(f"auto_map must be the type of string, but get {type(class_reference)} ." f"Please fill in the following format: module_file.function_name, such as," f"llama_model.LlamaForCausalLM") register_path = os.getenv("REGISTER_PATH", '') if not register_path: raise EnvironmentError("When configuring the 'auto_map' automatic registration function, " "REGISTER_PATH must be specified. " "It is recommended to complete this action" "through the official startup script " "'run_mindformer.py --register_path=module_file_path' " "or use 'export REGISTER_PATH=module_file_path' to complete this action.") if not os.path.realpath(register_path): raise EnvironmentError(f"REGISTER_PATH must be real path, but get {register_path}, " f"please specify the correct directory path.") register_path = os.path.realpath(os.getenv("REGISTER_PATH")) module_class = get_class_from_dynamic_module( class_reference=class_reference, pretrained_model_name_or_path=register_path) cls.register_cls(module_class, module_type=module_type)