# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Model runner utilities for text generation, enabling MindFormers to serve as a backend of MindIE.
"""
import os
import json
from typing import Optional, List, Union, Dict
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.communication.management import init
from mindspore.common.initializer import Zero
from mindspore._c_expression import swap_cache
from mindformers import models, MindFormerRegister, MindFormerModuleType
from mindformers import build_context, build_parallel_config, GenerationConfig
from mindformers import AutoModel, AutoConfig, AutoTokenizer
from mindformers.core.context import is_legacy_model
from mindformers.models.utils import convert_mstype, str_to_ms_type
from mindformers.utils import contains_safetensors_files
from mindformers.tools.logger import logger
from mindformers.tools.register.config import MindFormerConfig
from mindformers.trainer.utils import transform_and_load_checkpoint
from mindformers.tools.hub.dynamic_module_utils import get_class_from_dynamic_module
from mindformers.generation.parallel_decoding import parallel_decoding_control
from mindformers.version_control import check_delay_init_valid, need_nz
from mindformers.models import build_processor, PretrainedConfig
from mindformers.utils.load_checkpoint_utils import get_load_path_after_hf_convert
__all__ = ["ModelRunner"]
def register_auto_class(config, pretrained_model_name_or_path, class_type, use_fast=True):
"""convert to auto class"""
if config.model.model_config.auto_map:
class_auto = config["model"]["model_config"]["auto_map"]
if class_type == "AutoConfig" and \
config.model.model_config.type not in MindFormerRegister.registry[MindFormerModuleType.CONFIG]:
class_ref = class_auto[class_type]
config_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path)
MindFormerRegister.register_cls(config_class, module_type=MindFormerModuleType.CONFIG)
if class_type == "AutoTokenizer" and \
config.processor.tokenizer.type not in MindFormerRegister.registry[MindFormerModuleType.TOKENIZER]:
if use_fast and class_auto[class_type][1] is not None:
class_ref = class_auto[class_type][1]
else:
class_ref = class_auto[class_type][0]
tokenizer_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path)
MindFormerRegister.register_cls(tokenizer_class, module_type=MindFormerModuleType.TOKENIZER)
if class_type == "AutoModel" and \
config.model.arch.type not in MindFormerRegister.registry[MindFormerModuleType.MODELS]:
class_ref = class_auto[class_type]
model_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path)
MindFormerRegister.register_cls(model_class, module_type=MindFormerModuleType.MODELS)
if class_type == "AutoProcessor" and \
config.model.arch.type not in MindFormerRegister.registry[MindFormerModuleType.PROCESSOR]:
class_ref = class_auto[class_type]
processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path)
MindFormerRegister.register_cls(processor_class, module_type=MindFormerModuleType.PROCESSOR)
def is_multi_modal_model(config):
"""Check whether the config describes a multi-modal model."""
def count_type_num(model_config):
"""Recursively count the `type` keys nested in the model config."""
num = 0
for k, v in model_config.items():
if k == "type":
num += 1
if isinstance(v, dict):
num += count_type_num(v)
return num
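# A multi-modal config nests more than one sub-model config (e.g. an `llm_model`
# section plus other modality modules), each with its own `type` key, whereas a
# plain text model config carries a single `type`.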
return count_type_num(config.model.model_config) > 1
def get_model(model_name_or_path: str,
revision: Optional[str] = None,
trust_remote_code: Optional[bool] = False,
**kwargs):
"""
get_model API, supporting MindFormers (MF) as a backend of MindIEServer.
Args:
model_name_or_path (str):
A path to a *directory* containing the vocabulary files required by the tokenizer.
revision (`str`, *optional*, defaults to `None`):
The specific model version to use. It can be a branch name, a tag name, or a commit id.
trust_remote_code (`bool`, *optional*, defaults to `False`):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to `True` for repositories you trust and in which you have read the code, as it will
execute code present on the Hub on your local machine.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments for AutoTokenizer.from_pretrained.
Returns:
A tokenizer and an InputBuilder object (for multi-modal models, the processor is returned for both).
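Examples:
The following is an illustrative sketch for a text model; ``/path/to/model`` is a
placeholder directory containing the yaml config and tokenizer files.
>>> from mindformers.model_runner import get_model
>>> tokenizer, input_builder = get_model("/path/to/model", trust_remote_code=False)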
"""
if not os.path.exists(model_name_or_path) or not os.path.isdir(model_name_or_path):
raise ValueError(f"{model_name_or_path} does not exist or is not a directory.")
logger.debug(f"model_name_or_path is {model_name_or_path}")
config_path = _get_model_config(model_name_or_path)
config = MindFormerConfig(config_path)
model_type = config.model.arch.type
logger.info(f"The model type is: {model_type}")
register_auto_class(config, model_name_or_path, class_type="AutoTokenizer")
if is_multi_modal_model(config):
processor = build_processor(config.processor)
return processor, processor
use_fast = kwargs.get("use_fast", True)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, revision=revision,
trust_remote_code=trust_remote_code,
use_fast=use_fast)
input_builder = InputBuilder(tokenizer)
return tokenizer, input_builder
class ModelRunner:
"""
ModelRunner API, supporting MindFormers as a backend of MindIEServer.
Args:
model_path (str):
Path to a directory that contains the model config file and tokenizer files.
npu_mem_size (int):
NPU memory size (in GB) used for kv-cache.
cpu_mem_size (int):
CPU memory size (in GB) used for kv-cache.
block_size (int):
Block size used for kv-cache.
rank_id (int, optional):
Rank id used for infer. Default: ``0``.
world_size (int, optional):
Rank size used for infer. Default: ``1``.
npu_device_ids (list[int], optional):
Get npu_device_ids from MindIE config. Default: ``None``.
plugin_params (str, optional):
A JSON string that contains additional plugin parameters. Default: ``None``.
Returns:
A MindIEModelRunner object.
Examples:
>>> from mindformers import ModelRunner
>>> model_path = "/path/to/model/"  # contains model config file and tokenizer file.
>>> npu_mem_size = 3
>>> cpu_mem_size = 1
>>> block_size = 128
>>> rank_id = 0
>>> world_size = 1
>>> npu_device_ids = [0]
>>> model_runner = ModelRunner(model_path=model_path, npu_mem_size=npu_mem_size, cpu_mem_size=cpu_mem_size,
...                            block_size=block_size, rank_id=rank_id, world_size=world_size,
...                            npu_device_ids=npu_device_ids)
>>> type(model_runner)
<class 'mindformers.model_runner.MindIEModelRunner'>
"""
def __new__(cls, model_path, npu_mem_size, cpu_mem_size, block_size, rank_id=0, world_size=1,
npu_device_ids=None, plugin_params=None):
config_path = _get_model_config(model_path)
config = MindFormerConfig(config_path)
model_type = config.model.arch.type
logger.info(f"The model type is: {model_type}")
model_runner_cls = MindIEModelRunner
if model_type not in models.__all__:
try:
import importlib
model_runner_cls = importlib.import_module(model_type).MindIEModelRunner
except ImportError:
logger.info(f"import MindIEModelRunner from module {model_type} failed, "
f"and will use the default one defined in mindformers.")
model_runner = model_runner_cls(model_path, config_path, npu_mem_size, cpu_mem_size,
block_size, rank_id, world_size, npu_device_ids, plugin_params)
return model_runner
class MindIEModelRunner:
"""
Implementation of ModelRunner.
Args:
model_path (str):
Path to a directory that contains the model config file and tokenizer files.
config_path (str):
Path to the model's yaml config file.
npu_mem_size (int):
NPU memory size (in GB) used for kv-cache.
cpu_mem_size (int):
CPU memory size (in GB) used for kv-cache.
block_size (int):
Block size used for kv-cache.
rank_id (int):
Rank id used for infer.
world_size (int):
Rank size used for infer.
npu_device_ids (list[int]):
Get npu_device_ids from MindIE config.
plugin_params (str):
A JSON string that contains additional plugin parameters.
"""
def __init__(self, model_path, config_path, npu_mem_size, cpu_mem_size, block_size, rank_id=0,
world_size=1, npu_device_ids=None, plugin_params=None):
if plugin_params is not None and not isinstance(plugin_params, str):
raise ValueError("plugin params should be str type!")
self.dynamic_kv_cache_whitelist = ["ParallelLlamaForCausalLM", "InferenceDeepseekV3ForCausalLM"]
self.config = MindFormerConfig(config_path)
self.warmup_step = 2
self.is_multi_modal_model = is_multi_modal_model(self.config)
# register to Auto Class
register_auto_class(self.config, model_path, class_type="AutoConfig")
register_auto_class(self.config, model_path, class_type="AutoTokenizer")
register_auto_class(self.config, model_path, class_type="AutoModel")
# parallel predict with dynamic cluster.
if world_size > 1:
self.config.use_parallel = True
os.environ['MS_WORKER_NUM'] = str(world_size)
os.environ['MS_ROLE'] = 'MS_WORKER'
os.environ['MS_NODE_ID'] = str(rank_id)
ms.set_device("Ascend", npu_device_ids[0])
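# On rank 0, fork a child process to act as the dynamic-cluster scheduler (MS_SCHED);
# the parent and all other ranks join the cluster as workers (MS_WORKER) via init().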
if rank_id == 0 and os.fork() == 0:
os.environ['MS_ROLE'] = 'MS_SCHED'
init()
if self.config.use_parallel:
build_parallel_config(self.config)
self.model_config = AutoConfig.from_pretrained(config_path, parallel_config=self.config.parallel_config)
else:
self.model_config = AutoConfig.from_pretrained(config_path)
setattr(self.model_config, 'npu_mem_size', npu_mem_size)
if self.config.moe_config:
self.model_config.moe_config = self.config.moe_config
self.update_model_config(plugin_params)
if not self.config.use_parallel and npu_device_ids:
if len(npu_device_ids) != 1:
raise ValueError("npu_device_ids should only contain one device_id")
self.config.context.device_id = npu_device_ids[0]
build_context(self.config)
logger.info(f"Build context finished.")
self.use_legacy = is_legacy_model()
if self.is_multi_modal_model:
if isinstance(self.model_config.llm_model, PretrainedConfig):
llm_config = self.model_config.llm_model
else:
llm_config = self.model_config.llm_model.model_config
self.update_llm_config(llm_config, world_size, npu_mem_size, cpu_mem_size, block_size)
self.processor = build_processor(self.config.processor)
# adapt to mindie-llm
self.model_config.max_position_embedding = llm_config.max_position_embedding
else:
self.update_llm_config(self.model_config, world_size, npu_mem_size, cpu_mem_size, block_size)
self.generation_config = GenerationConfig.from_model_config(self.model_config)
# build tokenizer
if self.is_multi_modal_model:
self.tokenizer = self.processor.tokenizer
else:
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False, use_fast=True)
logger.info(f"Build tokenizer finished.")
# build model
network_delay_inited = False
if check_delay_init_valid():
from mindspore.nn.utils import no_init_parameters
with no_init_parameters():
self.model = AutoModel.from_config(self.model_config)
network_delay_inited = True
logger.info("Parameters are not initialized during model initialization.")
else:
self.model = AutoModel.from_config(self.model_config)
if npu_mem_size == -1 and type(self.model).__name__ not in self.dynamic_kv_cache_whitelist:
raise ValueError("npu_mem_size=-1 is only supported in parallel mode by models in the "
f"dynamic kv-cache whitelist: {self.dynamic_kv_cache_whitelist}.")
logger.info(f"Build model finished.")
self.load_checkpoint(network_delay_inited)
if not self.use_legacy or self.model_config.is_dynamic:
self.model.set_dynamic_inputs()
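# Pre-allocate host-side (CPU) key/value caches; swap() moves kv-cache blocks between
# these host tensors and the device caches for multi-batch and long-sequence inference.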
cpu_kv_shape = (self.cpu_num_blocks, block_size, self.num_kv_heads, self.head_size)
self.key_host = [ms.Parameter(ms.Tensor(shape=cpu_kv_shape, dtype=self.dtype, init=Zero()),
name=f"key_host_{i}", requires_grad=False).init_data() \
for i in range(self.num_layers)]
self.value_host = [ms.Parameter(ms.Tensor(shape=cpu_kv_shape, dtype=self.dtype, init=Zero()),
name=f"value_host_{i}", requires_grad=False).init_data() \
for i in range(self.num_layers)]
def load_checkpoint(self, network_delay_inited):
"""load checkpoint into model"""
ms_model = ms.Model(self.model)
batch_size = self.model_config.batch_size
seq_length = self.model_config.seq_length
input_ids = np.ones(shape=(batch_size, seq_length))
if self.use_legacy:
inputs = self.model.prepare_inputs_for_predict_layout(input_ids)
else:
inputs = None
self.config.load_checkpoint = get_load_path_after_hf_convert(self.config, self.model)
if self.config.load_checkpoint:
transform_and_load_checkpoint(self.config, ms_model, self.model, inputs, do_predict=True)
else:
logger.warning("No checkpoint loaded. Network will be inited randomly.")
if network_delay_inited:
self.model.init_parameters_data()
logger.info(f"Load checkpoints finished.")
def update_model_config(self, plugin_params):
"""update model config"""
self.model_config.parallel_decoding_params = None
default_plugin_configs = {'plugin_type': None}
if plugin_params == default_plugin_configs:
plugin_params = None
if plugin_params:
if not isinstance(plugin_params, dict):
plugin_params = json.loads(plugin_params)
plugin_params['parallel_decoding'] = plugin_params['plugin_type']
self.model_config.parallel_decoding_params = plugin_params
self.model_config.checkpoint_name_or_path = None
self.model_config.checkpoint_path = self.config.load_checkpoint
def update_llm_config(self, config, world_size, npu_mem_size, cpu_mem_size, block_size):
"""update llm model config"""
if self.use_legacy:
self.num_layers = config.num_layers
self.num_kv_heads = config.num_heads if config.n_kv_heads is None \
else config.n_kv_heads
# Divisibility of num_kv_heads by world_size is checked during model initialization.
self.num_kv_heads = self.num_kv_heads // world_size
self.head_size = config.hidden_size // config.num_heads
else:
self.num_layers = config.num_hidden_layers
self.num_kv_heads = config.num_attention_heads if config.num_key_value_heads is None \
else config.num_key_value_heads
self.num_kv_heads = self.num_kv_heads // world_size
self.head_size = config.hidden_size // config.num_attention_heads
kvcache_dtype = config.compute_dtype
if hasattr(self.model_config, "quantization_config") and \
self.model_config.quantization_config.kvcache_dtype in str_to_ms_type:
kvcache_dtype = self.model_config.quantization_config.kvcache_dtype
self.dtype = convert_mstype(kvcache_dtype)
kvcache_bytes = ms.Tensor(0, dtype=self.dtype).itemsize
total_head_size = self.num_kv_heads * self.head_size
if need_nz():
# Round total_head_size up to the nearest multiple of 16, as required by the NZ format.
total_head_size = -(total_head_size // -16) * 16
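# Each kv-cache block stores both key and value (hence the factor 2) for every layer:
#   bytes_per_block = block_size * total_head_size * kvcache_bytes * 2 * num_layers
# The GB memory budgets below are divided by this size to get the block counts.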
self.npu_num_blocks = (npu_mem_size * 1024 * 1024 * 1024) // \
(block_size * total_head_size * kvcache_bytes * 2 * self.num_layers)
self.cpu_num_blocks = (cpu_mem_size * 1024 * 1024 * 1024) // \
(block_size * total_head_size * kvcache_bytes * 2 * self.num_layers)
config.block_size = block_size
config.num_blocks = self.npu_num_blocks
if not hasattr(config, "max_position_embedding") or not config.max_position_embedding:
config.max_position_embedding = config.seq_length
def forward(self, input_ids: Union[List[int], List[List[int]]],
valid_length_each_example: List[int],
block_tables: Optional[Tensor] = None,
slot_mapping: Optional[Tensor] = None,
prefill: bool = True,
position_ids: Optional[Tensor] = None,
spec_mask: Optional[Tensor] = None,
q_seq_lens: Optional[Tensor] = None,
adapter_ids: Optional[List[str]] = None,
prefill_head_indices: Optional[Tensor] = None,
key_cache: Optional[List[Tensor]] = None,
value_cache: Optional[List[Tensor]] = None):
"""
Call self.model.forward() (legacy models) or self.model.forward_mcore() to run inference and
return the logits at the next position; either prefill or decode prediction can be chosen.
Args:
input_ids (List(List(int))):
Input ids after padding.
valid_length_each_example (List(int)):
Valid input length except padding.
block_tables (Tensor):
Params for paged attention.
slot_mapping (Tensor):
Params for paged attention.
prefill (bool):
Whether to do prefill prediction or decode prediction.
position_ids (Tensor):
Params for position encoding.
spec_mask (Tensor):
Params for paged attention.
q_seq_lens (Tensor):
Params for paged attention.
adapter_ids (List(str)):
Params for SLoRA requests.
prefill_head_indices (Tensor):
Params for pre-gather.
key_cache (List(Tensor), optional):
Params for key_cache, a group of tensors used for kvcache. Default: None.
value_cache (List(Tensor), optional):
Params for value_cache, a group of tensors used for kvcache. Default: None.
Returns:
logits (Tensor)
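Examples:
An illustrative sketch only; ``runner`` is assumed to be an initialized
MindIEModelRunner, and the block table / slot mapping values are placeholders
for a single short request.
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> logits = runner.forward(input_ids=[[1, 2, 3]],
...                         valid_length_each_example=[3],
...                         block_tables=Tensor([[0]], ms.int32),
...                         slot_mapping=Tensor([0, 1, 2], ms.int32),
...                         prefill=True)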
"""
is_warm_up = self.warmup_step > 0
valid_length_each_example = np.array(valid_length_each_example)
model_args = {"mindie_warm_up": is_warm_up}
if self.is_multi_modal_model and not is_warm_up:
if prefill:
input_ids, decode_args = self.processor.decode_input_ids(input_ids, valid_length_each_example)
decode_args.pop("position_ids", None)
model_args.update(decode_args)
if self.use_legacy:
res, current_idx = self.model.forward(input_ids=input_ids,
valid_length_each_example=valid_length_each_example,
block_tables=block_tables,
slot_mapping=slot_mapping,
prefill=prefill,
use_past=True,
position_ids=position_ids,
spec_mask=spec_mask,
q_seq_lens=q_seq_lens,
adapter_ids=adapter_ids,
prefill_head_indices=prefill_head_indices,
key_cache=key_cache,
value_cache=value_cache,
**model_args)
else:
res, current_idx = self.model.forward_mcore(input_ids=input_ids,
valid_length_each_example=valid_length_each_example,
block_tables=block_tables,
slot_mapping=slot_mapping,
prefill=prefill,
position_ids=position_ids,
spec_mask=spec_mask,
q_seq_lens=q_seq_lens,
adapter_ids=adapter_ids,
prefill_head_indices=prefill_head_indices,
key_cache=key_cache,
value_cache=value_cache,
**model_args)
logits = res[0] if isinstance(res, tuple) else res
if hasattr(self, 'model_config') and parallel_decoding_control(self.model_config):
return logits
if self.use_legacy and prefill and logits.shape[0] > len(current_idx):
logits = logits[Tensor(current_idx)]
if self.warmup_step > 0:
self.warmup_step -= 1
return logits
def swap(self, block_tables, swap_type):
"""
Swap key/value cache between host and device, to support multi-batch and long-sequence inference.
Args:
block_tables:
A 2-D array that contains the src and dst blocks to swap.
swap_type:
A bool value indicating the data direction: "True" for device-to-host, and "False" for host-to-device.
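Examples:
An illustrative sketch; ``runner`` is assumed to be an initialized MindIEModelRunner
and the block ids below are placeholders.
>>> swap_blocks = [[0, 0]]  # placeholder src/dst block table
>>> runner.swap(swap_blocks, swap_type=True)   # device -> host
>>> runner.swap(swap_blocks, swap_type=False)  # host -> device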
"""
for i in range(self.num_layers):
key_cache, value_cache = self.model.kvcache(i)
swap_cache(self.key_host[i], key_cache, ms.Tensor(block_tables), swap_type)
swap_cache(self.value_host[i], value_cache, ms.Tensor(block_tables), swap_type)
def generate_position_ids(self, input_ids):
"""Generate position ids for the given input ids."""
if not self.is_multi_modal_model or self.warmup_step > 0:
return range(len(input_ids))
return self.processor.decode_position_ids_from_input_ids(input_ids)
def _get_model_config(model_path):
"""
Get the path of the model's yaml config file.
Args:
model_path: Path of the model directory that contains a yaml config file.
Returns:
config_path, the path of the yaml config file.
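Examples:
An illustrative sketch; ``/path/to/model`` is a placeholder directory that holds a
yaml config file.
>>> config_path = _get_model_config("/path/to/model")
>>> config_path.endswith(".yaml")
True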
"""
if os.path.isdir(model_path):
yaml_list = [file
for file in os.listdir(model_path)
if file.endswith(".yaml")]
if yaml_list:
yaml_path = os.path.join(model_path, yaml_list[0])
else:
raise FileNotFoundError(f"There is no yaml file for model config in {model_path}.")
else:
raise ValueError(f"The path {model_path} is not exist.")
return yaml_path
class InputBuilder:
"""
Implementation of InputBuilder.
Args:
tokenizer (PreTrainedTokenizer):
A tokenizer for text processing.
chat_template (str):
A Jinja template to use for this conversion.
system_role_name (str):
The name of system role.
user_role_name (str):
The name of user role.
max_length (int):
The max length of input tokens.
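Examples:
An illustrative sketch; ``tokenizer`` is assumed to be a chat-capable tokenizer
loaded elsewhere, e.g. via AutoTokenizer.from_pretrained.
>>> builder = InputBuilder(tokenizer, max_length=2048)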
"""
def __init__(self, tokenizer, chat_template="", system_role_name="system", user_role_name="user", max_length=2048):
self.tokenizer = tokenizer
self.system_role_name = system_role_name
self.user_role_name = user_role_name
if chat_template:
self.tokenizer.chat_template = chat_template
self.max_length = max_length
self.rank = 0
self.adapt_to_max_length = False
def make_context(self, rank: int, conversation: List[Dict[str, str]], add_generation_prompt: bool = True,
adapt_to_max_length: bool = False, **kwargs):
"""
Build the token sequence for a conversation. Adapts to the interface of mindie-llm.
Args:
rank (int):
The rank id.
conversation (List[Dict[str, str]]):
A conversation object or list of dicts.
add_generation_prompt (bool, *optional*):
Whether to end the prompt with the token(s) that indicate the start of an assistant message.
adapt_to_max_length (bool, *optional*):
Whether the input tokens should be limited to no more than max_length.
Returns:
context_tokens
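Examples:
An illustrative sketch; ``builder`` is assumed to be an InputBuilder whose tokenizer
defines a chat template.
>>> conversation = [{"role": "user", "content": "Hello!"}]
>>> tokens = builder.make_context(rank=0, conversation=conversation)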
"""
self.rank = rank
self.adapt_to_max_length = adapt_to_max_length
context_tokens = self._apply_chat_template(conversation, add_generation_prompt=add_generation_prompt,
**kwargs)
return context_tokens
def _apply_chat_template(self, conversation: List[Dict[str, str]], **kwargs):
"""
Converts a Conversation to a list of token ids.
Args:
conversation (List[Dict[str, str]]):
A conversation object or list of dicts.
Returns:
input_ids
"""
if not hasattr(self.tokenizer, "apply_chat_template"):
raise RuntimeError("The tokenizer dose not implement apply_chat_template function.")
if not self.tokenizer.chat_template:
raise RuntimeError("The model does not appear to be a chat model because it is not configured with a "
"`chat_template`.")
input_ids = self.tokenizer.apply_chat_template(conversation, **kwargs)
return input_ids
def _load_distributed_safetensors(model, strategy_path, load_safetensors):
"""Load distributed safetensors"""
ms.load_distributed_checkpoint(
network=model,
predict_strategy=strategy_path,
unified_safetensors_dir=load_safetensors,
format='safetensors'
)
def _load_safetensors(model, load_safetensors):
"""Load single safetensors"""
sf_list = [sf for sf in os.listdir(load_safetensors) if sf.endswith('.safetensors')]
if not sf_list:
raise FileNotFoundError(f"There are no safetensors files under the given path {load_safetensors}.")
for sf in sf_list:
ms.load_checkpoint(
ckpt_file_name=os.path.join(load_safetensors, sf),
net=model,
format='safetensors'
)
def _check_valid_safetensors_path(path):
"""Check whether the safetensors path is valid"""
if not isinstance(path, (str, os.PathLike)):
raise ValueError(f"path must be a str or os.PathLike, but got {path} as type {type(path)}.")
if not os.path.exists(path):
raise ValueError(f"path {path} does not exist.")
if contains_safetensors_files(path):
return
raise ValueError(f"load_checkpoint path {path} is not a valid path for safetensors.")