# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network helper base class."""
from typing import Union, List
import numpy as np
from mindspore.nn import Cell
class NetworkHelper:
    """Abstract helper that decouples quantization algorithms from a concrete network framework.

    Subclasses (e.g. ``MFLlama2Helper``) implement each method for a specific
    model family; every method here raises ``NotImplementedError``.
    """

    def create_network(self):
        """
        Create a network.

        Returns:
            Created network.

        Raises:
            NotImplementedError: Always; subclasses must override.

        Examples:
            >>> from mindspore_gs.ptq.network_helpers.mf_net_helpers import MFLlama2Helper
            >>> from mindformers.tools.register.config import MindFormerConfig
            >>> mf_yaml_config_file = "/path/to/mf_yaml_config_file"
            >>> mfconfig = MindFormerConfig(mf_yaml_config_file)
            >>> helper = MFLlama2Helper(mfconfig)
            >>> network = helper.create_network()
        """
        raise NotImplementedError

    def get_spec(self, name: str):
        """
        Get network specific, such as batch_size, seq_length and so on.

        Args:
            name (str): Name of specific.

        Returns:
            Object as network specific.

        Raises:
            NotImplementedError: Always; subclasses must override.

        Examples:
            >>> from mindspore_gs.ptq.network_helpers.mf_net_helpers import MFLlama2Helper
            >>> from mindformers.tools.register.config import MindFormerConfig
            >>> mf_yaml_config_file = "/path/to/mf_yaml_config_file"
            >>> mfconfig = MindFormerConfig(mf_yaml_config_file)
            >>> helper = MFLlama2Helper(mfconfig)
            >>> helper.get_spec("batch_size")
            1 (The output is related to the `mfconfig`, and the result here is just for example.)
        """
        raise NotImplementedError

    def create_tokenizer(self, **kwargs):
        """
        Get network tokenizer.

        Args:
            kwargs (Dict): Extensible parameter for subclasses.

        Returns:
            Object as network tokenizer.

        Raises:
            NotImplementedError: Always; subclasses must override.

        Examples:
            >>> from mindspore_gs.ptq.network_helpers.mf_net_helpers import MFLlama2Helper
            >>> from mindformers.tools.register.config import MindFormerConfig
            >>> mf_yaml_config_file = "/path/to/mf_yaml_config_file"
            >>> mfconfig = MindFormerConfig(mf_yaml_config_file)
            >>> helper = MFLlama2Helper(mfconfig)
            >>> helper.create_tokenizer()
            LlamaTokenizer(name_or_path='', vocab_size=32000, model_max_length=100000, added_tokens_decoder={
            0: AddedToken("<unk>", rstrip=False, lstrip=False, normalized=True, special=True),
            1: AddedToken("<s>", rstrip=False, lstrip=False, normalized=True, special=True),
            2: AddedToken("</s>", rstrip=False, lstrip=False, normalized=True, special=True),
            })
        """
        raise NotImplementedError

    def generate(self, network: Cell, input_ids: Union[np.ndarray, List[int], List[List[int]]],
                 max_new_tokens=None, **kwargs):
        """
        Invoke `network` and generate tokens.

        Args:
            network (Cell): Network to generate tokens.
            input_ids (Union[numpy.ndarray, List[int], List[List[int]]]): Input tokens for generate.
            max_new_tokens (int, optional): Max number of tokens to be generated. Defaults to
                ``None``; subclasses define the fallback (typically 1).
            kwargs (Dict): Extensible parameter for subclasses.

        Returns:
            A list as generated tokens.

        Raises:
            NotImplementedError: Always; subclasses must override.

        Examples:
            >>> import numpy as np
            >>> from mindspore import context
            >>> from mindspore_gs.ptq.network_helpers.mf_net_helpers import MFLlama2Helper
            >>> from mindformers import LlamaForCausalLM, LlamaConfig
            >>> from mindformers.tools.register.config import MindFormerConfig
            >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
            >>> mf_yaml_config_file = "/path/to/mf_yaml_config_file"
            >>> mfconfig = MindFormerConfig(mf_yaml_config_file)
            >>> helper = MFLlama2Helper(mfconfig)
            >>> network = LlamaForCausalLM(LlamaConfig(**mfconfig.model.model_config))
            >>> input_ids = np.array([[1, 10000]], dtype = np.int32)
            >>> helper.generate(network, input_ids)
            array([[ 1, 10000, 10001]], dtype=int32)
        """
        raise NotImplementedError