# Source code for mindformers.tools.register.config

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Transformer-Config dict parse module """

import argparse
from argparse import Action
import copy
import os
from collections import OrderedDict
from typing import Union

import yaml
from mindformers.tools.check_rules import check_yaml_depth_before_loading
from .template import ConfigTemplate

BASE_CONFIG = 'base_config'


class DictConfig(dict):
    """A ``dict`` whose keys can also be read, written and deleted as attributes.

    Reading a key that is not present through attribute access returns ``None``
    instead of raising ``AttributeError``.
    """

    def __init__(self, **kwargs):
        super(DictConfig, self).__init__()
        self.update(kwargs)

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails, so real dict methods
        # (``get``, ``items``, ...) take precedence over keys of the same name.
        if key not in self:
            return None
        return self[key]

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        del self[key]

    def __deepcopy__(self, memo=None):
        """Deep copy operation on arbitrary DictConfig (or subclass) objects.

        Args:
            memo (dict) : Objects that already copied, keyed by ``id``.

        Returns:
            DictConfig : A deep copy of this object, of the same class.
        """
        if memo is None:
            memo = {}
        config = self.__class__()
        # Register the copy before recursing so self-references (direct or
        # indirect) resolve to the new object instead of recursing forever.
        memo[id(self)] = config
        for key, value in self.items():
            config[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
        return config

    def to_dict(self):
        """
        for yaml dump,
        transform from Config to a strict dict class.

        Nested ``DictConfig`` values are converted recursively, including those
        stored inside plain lists and tuples.
        """
        def _convert(value):
            if isinstance(value, DictConfig):
                return value.to_dict()
            # Exact type check leaves list/tuple subclasses (e.g. namedtuples)
            # untouched, since their constructors take different arguments.
            if type(value) in (list, tuple):
                return type(value)(_convert(item) for item in value)
            return value

        return {key: _convert(val) for key, val in self.items()}


class MindFormerConfig(DictConfig):
    """
    A class for configuration that inherits from Python's dict class.
    Can parse configuration parameters from yaml files or dict instances.

    Args:
        args (Any): Extensible parameter list. Each item may be a path to a yaml configuration
            file (a string ending with ``yaml`` or ``yml``) or a configuration dictionary; items
            are merged in order. Other strings are ignored.
        kwargs (Any): Extensible parameter dictionary of configuration entries, merged after
            (and overriding) anything loaded from ``args``.

    Returns:
        An instance of the class.

    Examples:
        >>> from mindformers.tools import MindFormerConfig
        >>>
        >>> # test.yaml:
        >>> # a: 1
        >>> cfg = MindFormerConfig('./test.yaml')
        >>> cfg.a
        1
        >>> cfg = MindFormerConfig(**dict(a=1, b=dict(c=[0,1])))
        >>> cfg.b
        {'c': [0, 1]}
    """

    def __init__(self, *args, **kwargs):
        super(MindFormerConfig, self).__init__()
        cfg_dict = {}

        load_from_file = False
        for arg in args:
            if isinstance(arg, str):
                # Only yaml paths are loaded; any other string is ignored.
                if arg.endswith(('yaml', 'yml')):
                    cfg_dict.update(MindFormerConfig._file2dict(arg))
                    load_from_file = True
            elif isinstance(arg, dict):
                # Configuration dictionaries passed positionally are merged as-is.
                cfg_dict.update(arg)

        # Keyword arguments are applied last so they override args contents.
        cfg_dict.update(kwargs)

        # The template step only runs for configurations loaded from a file.
        if load_from_file:
            ConfigTemplate.apply_template(cfg_dict)
        MindFormerConfig._dict2config(self, cfg_dict)

    def merge_from_dict(self, options):
        """
        Merge options into config.

        Args:
            options (dict): Configuration options that need to be merged. Keys may be
                dotted paths (``'a.b.c'``) addressing nested entries.

        Examples:
            >>> from mindformers.tools import MindFormerConfig
            >>>
            >>> options = {'model.arch': 'LlamaForCausalLM'}
            >>> cfg = MindFormerConfig(**dict(model=dict(model_config=dict(type='LlamaConfig'))))
            >>> cfg.merge_from_dict(options)
            >>> print(cfg)
            {'model': {'model_config': {'type': 'LlamaConfig'}, 'arch': 'LlamaForCausalLM'}}
        """
        # Expand dotted keys into a nested dictionary tree.
        option_cfg_dict = {}
        for full_key, value in options.items():
            d = option_cfg_dict
            key_list = full_key.split('.')
            for sub_key in key_list[:-1]:
                d.setdefault(sub_key, MindFormerConfig())
                d = d.get(sub_key)
            sub_key = key_list[-1]
            d[sub_key] = value
        merge_dict = MindFormerConfig._merge_a_into_b(option_cfg_dict, self)
        MindFormerConfig._dict2config(self, merge_dict)

    @staticmethod
    def _merge_a_into_b(a, b):
        """Merge dict ``a`` into dict ``b``.

        Values in ``a`` overwrite ``b``. Nested dicts are merged recursively only
        when both sides hold a dict for the same key; otherwise ``a`` wins.

        Args:
            a (dict) : The source dict to be merged into b.
            b (dict) : The origin dict to be fetch keys from ``a``. May be ``None``.

        Returns:
            dict: A shallow copy of ``b`` updated with ``a`` (``b`` itself is not modified).
        """
        merged = b.copy() if b is not None else a.copy()
        for key, value in a.items():
            # Recursing into a non-dict (list, scalar) would fail, so only merge
            # when the existing value is itself a mapping.
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = MindFormerConfig._merge_a_into_b(value, merged[key])
            else:
                merged[key] = value
        return merged

    @staticmethod
    def _file2dict(filename=None):
        """Convert config file to dictionary.

        Base files listed under the ``base_config`` key are loaded recursively,
        resolved relative to the directory of ``filename``. Later base files
        override earlier ones at the top level, and keys in ``filename`` itself
        override all bases.

        Args:
            filename (str) : config file.

        Returns:
            dict: The parsed (and base-merged) configuration; empty for an empty file.

        Raises:
            NameError: If ``filename`` is ``None``.
        """
        if filename is None:
            raise NameError('The config filename cannot be None.')
        filepath = os.path.realpath(filename)
        with open(filepath, encoding='utf-8') as fp:
            check_yaml_depth_before_loading(fp)
            fp.seek(0)
            # safe_load returns None for an empty document; treat that as {}.
            cfg_dict = yaml.safe_load(fp) or {}
        cfg_dict = OrderedDict(sorted(cfg_dict.items()))

        # Load base config file(s).
        if BASE_CONFIG in cfg_dict:
            cfg_dir = os.path.dirname(filename)
            base_filenames = cfg_dict.pop(BASE_CONFIG)
            if not isinstance(base_filenames, list):
                base_filenames = [base_filenames]
            base_cfg_dict = {}
            for base_filename in base_filenames:
                base_cfg_dict.update(
                    MindFormerConfig._file2dict(os.path.join(cfg_dir, base_filename)))
            # Merge this file's entries on top of the bases.
            cfg_dict = MindFormerConfig._merge_a_into_b(cfg_dict, base_cfg_dict)
        return cfg_dict

    @staticmethod
    def _dict2config(config, dic):
        """Convert dictionary to config, in place.

        Every nested dict value becomes a fresh ``MindFormerConfig``; other values
        are stored as-is. Does nothing when ``dic`` is not a dict.

        Args:
            config : Config object to fill.
            dic (dict) : dictionary
        """
        if not isinstance(dic, dict):
            return
        for key, value in dic.items():
            if isinstance(value, dict):
                sub_config = MindFormerConfig()
                dict.__setitem__(config, key, sub_config)
                MindFormerConfig._dict2config(sub_config, value)
            else:
                config[key] = value

    def get_value(self, levels: Union[str, list], default=None):
        """Get the attribute according to levels, if not exist return default.

        Args:
            levels : the level to be accessed, as a dotted string or a list of keys.
            default : value returned when the path does not exist (None by default).

        Returns:
            default or value of the key to be accessed

        Examples:
            >>> config = MindFormerConfig(**{'context': {'mode': 'GRAPH_MODE'}, 'parallel': {}})
            >>> config.get_value(['context', 'mode'])
            'GRAPH_MODE'
            >>> config.get_value(['context', 'mode'], 'DEFAULT_MODE')
            'GRAPH_MODE'
            >>> print(config.get_value(['context', 'fake_mode']))
            None
            >>> config.get_value(['context', 'fake_mode'], 'DEFAULT_MODE')
            'DEFAULT_MODE'
            >>> config.get_value('context.mode', 'DEFAULT_MODE')
            'GRAPH_MODE'
            >>> config.get_value('context.fake_mode', 'DEFAULT_MODE')
            'DEFAULT_MODE'
        """
        if not levels:
            return default
        if isinstance(levels, str):
            levels = levels.split('.')
        node = self
        # Walk by key lookup (not getattr, which would return dict methods for
        # keys such as 'items'); any missing/empty/non-dict hop yields default.
        for level in levels[:-1]:
            node = node.get(str(level)) if isinstance(node, dict) else None
            if not node:
                return default
        if not isinstance(node, dict):
            return default
        return node.get(str(levels[-1]), default)

    def set_value(self, levels: Union[list, str], value):
        """set the attribute according to levels.

        Missing intermediate levels are created; an intermediate plain dict is
        converted to ``MindFormerConfig`` (keeping its entries) and any other
        intermediate value is replaced by a new, empty ``MindFormerConfig``.

        Args:
            levels : the level to be accessed, as a dotted string or a list of keys.
            value : The value to be set

        Returns:
            None

        Examples:
            >>> config = MindFormerConfig(**{'context': {'mode': 0}, 'parallel': {}, 'test': None})
            >>> config.set_value('context.mode', 1)
            >>> config
            {'context': {'mode': 1}, 'parallel': {}, 'test': None}
            >>> config.set_value(['context', 'device_id'], 2)
            >>> config.context
            {'mode': 1, 'device_id': 2}
            >>> config.set_value('parallel', {'hello': 'mf'})
            >>> config.set_value('parallel.data', 8)
            >>> config.parallel
            {'hello': 'mf', 'data': 8}
            >>> config.set_value('test.model', 1)
            >>> config.test
            {'model': 1}
        """
        if not levels:
            return
        if isinstance(levels, str):
            levels = levels.split('.')
        if len(levels) == 1:
            setattr(self, levels[-1], value)
            return
        sub_config = self.get(levels[0])
        if not isinstance(sub_config, MindFormerConfig):
            converted = MindFormerConfig()
            if isinstance(sub_config, dict):
                MindFormerConfig._dict2config(converted, sub_config)
            sub_config = converted
            setattr(self, levels[0], sub_config)
        sub_config.set_value(levels[1:], value)
class ActionDict(Action):
    """
    Argparse action to split an option into KEY=VALUE from on the first =
    and append to dictionary.
    List options can be passed as comma separated values.
    i.e. 'KEY=Val1,Val2,Val3' or with explicit brackets
    i.e. 'KEY=[Val1,Val2,Val3]'.
    """

    @staticmethod
    def _parse_int_float_bool(val):
        """convert string val to int or float or bool or do nothing."""
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        upper_val = val.upper()
        if upper_val in ['TRUE', 'FALSE']:
            # Compare the upper-cased string (not the bound method) to 'TRUE'.
            return upper_val == 'TRUE'
        return val

    @staticmethod
    def find_next_comma(val_str):
        """find the position of next comma in the string.

        Only commas at the top level (outside any '(...)' or '[...]') count.
        Returns ``len(val_str)`` when there is no such comma.

        note:
            '(' and ')' or '[' and']' must appear in pairs or not exist.
        """
        if val_str.count('(') != val_str.count(')') or \
                (val_str.count('[') != val_str.count(']')):
            raise ValueError("( and ) or [ and ] must appear in pairs or not exist.")
        # Running open-minus-close counts for the prefix before each index;
        # equivalent to re-counting the prefix every time, but linear.
        paren_depth = 0
        bracket_depth = 0
        for idx, char in enumerate(val_str):
            if char == ',' and paren_depth == 0 and bracket_depth == 0:
                return idx
            if char == '(':
                paren_depth += 1
            elif char == ')':
                paren_depth -= 1
            elif char == '[':
                bracket_depth += 1
            elif char == ']':
                bracket_depth -= 1
        return len(val_str)

    @staticmethod
    def _parse_value_iter(val):
        """Convert string format as list or tuple to python list object or tuple object.

        Args:
            val (str) : Value String

        Returns:
            list or tuple, or a scalar (int/float/bool/str) when ``val`` has no
            brackets and no comma.

        Examples:
            >>> ActionDict._parse_value_iter('1,2,3')
            [1, 2, 3]
            >>> ActionDict._parse_value_iter('[1,2,3]')
            [1, 2, 3]
            >>> ActionDict._parse_value_iter('(1,2,3)')
            (1, 2, 3)
            >>> ActionDict._parse_value_iter('[1,[1,2],(1,2,3)]')
            [1, [1, 2], (1, 2, 3)]
        """
        # strip ' and " and delete whitespace
        val = val.strip('\'\"').replace(" ", "")
        is_tuple = False
        if val.startswith('(') and val.endswith(')'):
            is_tuple = True
            # remove start '(' and end ')'
            val = val[1:-1]
        elif val.startswith('[') and val.endswith(']'):
            # remove start '[' and end ']'
            val = val[1:-1]
        elif ',' not in val:
            return ActionDict._parse_int_float_bool(val)

        values = []
        while val:
            comma_idx = ActionDict.find_next_comma(val)
            values.append(ActionDict._parse_value_iter(val[:comma_idx]))
            val = val[comma_idx + 1:]
        if is_tuple:
            return tuple(values)
        return values

    def __call__(self, parser, namespace, values, option_string=None):
        """Parse each ``KEY=VALUE`` item and store the resulting dict on ``namespace``."""
        options = {}
        for key_value in values:
            if '=' not in key_value:
                # Report through argparse so the user gets a usage error, not a traceback.
                raise argparse.ArgumentError(
                    self, "expected KEY=VALUE, got {!r}".format(key_value))
            key, value = key_value.split('=', maxsplit=1)
            options[key] = self._parse_value_iter(value)
        setattr(namespace, self.dest, options)


def ordered_yaml_dump(data, stream=None, yaml_dumper=yaml.SafeDumper,
                      object_pairs_hook=OrderedDict, **kwargs):
    """Dump Dict to Yaml File in Orderedly.

    Registers a representer so ``object_pairs_hook`` instances are emitted as
    plain yaml mappings in their iteration order. Extra ``kwargs`` are passed to
    ``yaml.dump``.
    """
    class OrderedDumper(yaml_dumper):
        pass

    def _dict_representer(dumper, data):
        return dumper.represent_mapping(
            yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
            data.items())

    OrderedDumper.add_representer(object_pairs_hook, _dict_representer)
    return yaml.dump(data, stream, OrderedDumper, **kwargs)


def parse_args():
    """
    Parse arguments from `yaml or yml` config file.

    Returns:
        object: argparse object.
    """
    parser = argparse.ArgumentParser("Transformer Config.")
    parser.add_argument('-c', '--config', type=str, default="",
                        help='Enter the path of the model config file.')
    return parser.parse_args()
Args: val (str) : Value String Returns: list or tuple Examples: >>> ActionDict._parse_value_iter('1,2,3') [1,2,3] >>> ActionDict._parse_value_iter('[1,2,3]') [1,2,3] >>> ActionDict._parse_value_iter('(1,2,3)') (1,2,3) >>> ActionDict._parse_value_iter('[1,[1,2],(1,2,3)') [1, [1, 2], (1, 2, 3)] """ # strip ' and " and delete whitespace val = val.strip('\'\"').replace(" ", "") is_tuple = False if val.startswith('(') and val.endswith(')'): is_tuple = True # remove start '(' and end ')' val = val[1:-1] elif val.startswith('[') and val.endswith(']'): # remove start '[' and end ']' val = val[1:-1] elif ',' not in val: return ActionDict._parse_int_float_bool(val) values = [] len_of_val = len(val) while len_of_val > 0: comma_idx = ActionDict.find_next_comma(val) ele = ActionDict._parse_value_iter(val[:comma_idx]) values.append(ele) val = val[comma_idx + 1:] len_of_val = len(val) if is_tuple: return tuple(values) return values def __call__(self, parser, namespace, values, option_string=None): options = {} for key_value in values: key, value = key_value.split('=', maxsplit=1) options[key] = self._parse_value_iter(value) setattr(namespace, self.dest, options) def ordered_yaml_dump(data, stream=None, yaml_dumper=yaml.SafeDumper, object_pairs_hook=OrderedDict, **kwargs): """Dump Dict to Yaml File in Orderedly.""" class OrderedDumper(yaml_dumper): pass def _dict_representer(dumper, data): return dumper.represent_mapping( yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, data.items()) OrderedDumper.add_representer(object_pairs_hook, _dict_representer) return yaml.dump(data, stream, OrderedDumper, **kwargs) def parse_args(): """ Parse arguments from `yaml or yml` config file. Returns: object: argparse object. """ parser = argparse.ArgumentParser("Transformer Config.") parser.add_argument('-c', '--config', type=str, default="", help='Enter the path of the model config file.') return parser.parse_args()