Source code for mindformers.generation.text_generator

# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
For text generation
"""
import copy
import time
from typing import Optional, List, Union, Dict

import numpy as np
import mindspore as ms
from mindspore import mint
from mindspore.ops import functional as F
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor

from mindformers.generation.beam_search import BeamSearchScorer
from mindformers.generation.generation_config import GenerationConfig
from mindformers.generation.logits_process import (LogitNormalization, LogitsProcessorList,
                                                   RepetitionPenaltyLogitsProcessor,
                                                   TemperatureLogitsWarper, TopKLogitsWarper,
                                                   TopPLogitsWarper, MinLengthLogitsProcessor,
                                                   MinNewTokensLengthLogitsProcessor)
from mindformers import version_control
from mindformers.models.tokenization_utils import PreTrainedTokenizer
from mindformers.generation.streamers import BaseStreamer
from mindformers.generation.utils import softmax_with_threads, topk, GenerateOutput, InferOutput
from mindformers.modules.block_tables import BlockTables
from mindformers.tools.logger import logger
from mindformers.tools.utils import is_pynative, get_context
from mindformers.tools.debug_info import DetailedLatency, Profiling
from mindformers.generation.parallel_decoding import parallel_decoding_control, parallel_decoding_process
from mindformers.generation.parallel_decoding_mcore import la_pre_process

__all__ = ["GenerationMixin"]


def get_valid_length_each_example(input_ids, pad_token_id):
    """get valid length and max length in a batch"""
    batch_size = input_ids.shape[0]
    valid_length_each_example = []
    for i in range(batch_size):
        # argwhere returns the last non-pad index, so add 1 to turn it into a length
        valid_length_each_example.append(
            np.max(np.argwhere(input_ids[i] != pad_token_id))
            + 1
        )
    valid_length_each_example = np.array(valid_length_each_example)
    logger.debug("Get the valid for each example is: %s", valid_length_each_example)
    max_length = np.max(valid_length_each_example)
    return valid_length_each_example, max_length
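
# Usage sketch (illustrative, not part of the original module): with pad_token_id=0 the helper
# returns the per-example valid lengths and the batch-wide maximum, roughly:
#   >>> ids = np.array([[1, 5, 7, 0, 0], [2, 3, 0, 0, 0]])
#   >>> get_valid_length_each_example(ids, pad_token_id=0)
#   (array([3, 2]), 3)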


class GenerationMode:
    """
    Possible generation modes.
    """

    # Non-beam methods
    GREEDY_SEARCH = "greedy_search"
    SAMPLE = "sample"
    # Beam methods
    BEAM_SEARCH = "beam_search"
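
# Mode-selection sketch (illustrative only): `GenerationMixin._get_generation_mode` below maps the
# generation config onto these constants, with `num_beams > 1` taking precedence over `do_sample`:
#   >>> num_beams, do_sample = 1, True
#   >>> GenerationMode.BEAM_SEARCH if num_beams > 1 else (
#   ...     GenerationMode.SAMPLE if do_sample else GenerationMode.GREEDY_SEARCH)
#   'sample'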


[docs]class GenerationMixin: """A class providing all functions for autoregressive text generation, used as a mixin with PreTrainedModel.""" def __init__(self): self.detailed_latency = DetailedLatency() self.profile = Profiling() self.block_mgr = None self.use_mint_op = version_control.use_mint_op() self.is_pynative = is_pynative() self.argmax = mint.argmax if self.use_mint_op else ms.ops.argmax self._pre_set_phase = None self._exec_add_flags = True self.gather = P.Gather() def _set_network_phase(self, phase): self._pre_set_phase = phase self._exec_add_flags = True def _set_block_mgr(self, batch_size, seq_length): """ Set model block table mgr function. """ if not self.block_mgr: self.block_mgr = BlockTables(self.config.num_blocks, self.config.block_size, seq_length) if self.block_mgr: self.block_mgr.init_cache_engine(batch_size) def _prepare_inputs_for_prefill_flatten(self, input_ids, batch_valid_length, slot_mapping, model_inputs): """prepare inputs ids for prefill flatten""" batch_valid_length_bs = batch_valid_length.shape[0] # [bs,] input_ids_list = [] for i in range(batch_valid_length_bs): context_len = batch_valid_length[i] input_ids_list.append(input_ids[i][:context_len]) input_ids = np.concatenate(input_ids_list, 0) input_ids = input_ids.reshape((1, -1)) slot_mapping = np.delete(slot_mapping, np.where(slot_mapping == -1)) model_inputs["input_ids"] = Tensor.from_numpy(input_ids.astype(np.int32)) model_inputs["slot_mapping"] = Tensor.from_numpy(slot_mapping) return model_inputs # pylint: disable=W0613 def prepare_inputs_for_generation(self, input_ids, **kwargs): """ prepare inputs for generation. A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()` """ model_inputs = {"input_ids": Tensor.from_numpy(input_ids.astype(np.int32))} if self.is_pynative: model_inputs = {} if self.config.is_dynamic and "origin_inputs" in kwargs and self.use_past: input_ids = kwargs["origin_inputs"] model_inputs["input_ids"] = Tensor.from_numpy(input_ids.astype(np.int32)) else: if self.config.is_dynamic: prefill = kwargs.get("prefill") if prefill and "origin_inputs" in kwargs: origin_inputs = kwargs["origin_inputs"] batch_valid_length = kwargs.get("valid_length_each_example") slot_mapping = kwargs.get("slot_mapping") model_inputs = self._prepare_inputs_for_prefill_flatten(origin_inputs, batch_valid_length, slot_mapping, model_inputs) return model_inputs def add_flags_custom(self, is_first_iteration): """ Add customized attributes for specific cells in the model. If the model does not implement this method, this will add customized attributes for all cells in the model recursively. Args: is_first_iteration (bool): Network configuration information. Indicate whether current iteration is the first iteration in prediction. """ self.add_flags_recursive(is_first_iteration=is_first_iteration) def add_flags_custom_mcore(self, is_prefill): """ Add customized attributes for specific cells in the model. If the model does not implement this method, this will add customized attributes for all cells in the model recursively. Args: is_first_iteration (bool): Network configuration information. Indicate whether current iteration is the first iteration in prediction. """ self.add_flags_recursive(is_prefill=is_prefill) # pylint: disable=W0613 def update_model_kwargs_before_generate(self, input_ids, model_kwargs: dict): """ update model kwargs before generate. If your model needs to update model kwargs before generate, implement this method in your model, else do nothing. 
""" return def slice_incremental_inputs(self, model_inputs: dict, current_index): """used for non-first iterations, slice the inputs to length 1.""" input_ids = model_inputs.pop("input_ids") if isinstance(input_ids, Tensor): if input_ids.shape[-1] == 1: model_inputs["input_ids"] = input_ids return input_ids = input_ids.asnumpy() current_index_tmp = current_index - np.arange(input_ids.size, step=input_ids.shape[1]) arg = np.arange(input_ids.shape[0]) inputs_tmp = input_ids[arg, current_index_tmp].reshape(-1, 1) model_inputs["input_ids"] = Tensor.from_numpy(inputs_tmp.astype(np.int32)) def process_logits(self, logits, current_index=None, keep_all=False): """Process the logits""" logits = logits.reshape(-1, logits.shape[-1]) if not keep_all and current_index is not None: index = current_index.view(-1,) logits = P.Gather()(logits, index, 0) outputs = P.LogSoftmax(-1)(logits) outputs = F.tensor_pow(np.e, outputs) return outputs def get_logits_processor(self, generation_config: GenerationConfig, input_ids_seq_length: int, logits_processor: Optional[LogitsProcessorList]): """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`] instances used to modify the scores of the language model head. """ # instantiate processors list processors = LogitsProcessorList() if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: processors.append(RepetitionPenaltyLogitsProcessor(repetition_penalty=generation_config.repetition_penalty)) if ( generation_config.min_length is not None and generation_config.eos_token_id is not None and generation_config.min_length > 0 ): processors.append( MinLengthLogitsProcessor( generation_config.min_length, generation_config.eos_token_id, generation_config.pad_token_id ) ) if ( generation_config.min_new_tokens is not None and generation_config.eos_token_id is not None and generation_config.min_new_tokens > 0 ): processors.append( MinNewTokensLengthLogitsProcessor( input_ids_seq_length, generation_config.min_new_tokens, generation_config.eos_token_id, generation_config.pad_token_id ) ) processors = self._merge_processor_list(processors, logits_processor) # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: processors.append(LogitNormalization()) return processors def _merge_processor_list(self, default_list: LogitsProcessorList, custom_list: LogitsProcessorList): """merge custom processor list with default list.""" if not custom_list: return default_list for default in default_list: for custom in custom_list: if type(custom) is type(default): object_type = "logits processor" raise ValueError( f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to" f" `.generate()`, but it has already been created with the values {default}." f" {default} has been created by passing the corresponding arguments to generate or" f" by the model's config default values. If you just want to change the default values" f" of {object_type} consider passing them as arguments to `.generate()`" f" instead of using a custom {object_type}." ) default_list.extend(custom_list) return default_list def get_logits_warper(self, generation_config: GenerationConfig): """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances used for multinomial sampling. 
""" # instantiate warpers list warpers = LogitsProcessorList() # all samplers can be found in `generation_utils_samplers.py` if generation_config.temperature is not None and generation_config.temperature != 1.0: warpers.append(TemperatureLogitsWarper(generation_config.temperature)) # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: warpers.append(LogitNormalization()) if not generation_config.do_sample: return warpers min_tokens_to_keep = 1 if generation_config.top_k is not None and generation_config.top_k > 0: warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)) if generation_config.top_p is not None: warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)) return warpers def _get_generation_mode(self, generation_config: GenerationConfig): """determine the generation mode by config""" if generation_config.num_beams == 1: if generation_config.do_sample: return GenerationMode.SAMPLE return GenerationMode.GREEDY_SEARCH return GenerationMode.BEAM_SEARCH def _prepare_model_inputs_for_decoder(self, input_ids, input_mask): """generate the inputs for the decoder""" batch_size = input_ids.shape[0] encoder_mask = Tensor(input_mask, mstype.float32) encoder_output = self.encoder_forward( Tensor(input_ids, mstype.int32), encoder_mask ) input_ids = np.zeros((batch_size, self.config.max_decode_length)) logger.debug("Decoder: pad the origin inputs into shape: %s", input_ids.shape) target_mask = np.zeros_like(input_ids) target_mask[:, 0] = 1 # As the decoder is generating from [START] token return encoder_output, encoder_mask, input_ids, target_mask def _pad_inputs_using_max_length(self, origin_inputs, pad_token_id=0): """pad the input_ids to the max_length""" pad_length = self.config.seq_length - origin_inputs.shape[-1] if pad_length < 0: raise ValueError( f"origin_inputs size is {origin_inputs.shape}, you should" f"increase the seq_length of the model {self.config.seq_length}." ) # Pad original inputs to model_origin_max_length input_ids = np.pad( origin_inputs, ((0, 0), (0, pad_length)), "constant", constant_values=(0, pad_token_id), ) return input_ids def _incremental_infer(self, model_inputs: dict, prefill, current_index, key_cache=None, value_cache=None): """model forward for incremental infer.""" # Claim the first graph if key_cache is not None and value_cache is not None: model_inputs = {**model_inputs, 'key_cache': key_cache, 'value_cache': value_cache} if prefill: self.phase = "prefill" if self._pre_set_phase: self.phase = f"prefill_{self._pre_set_phase}" # In dynamic shape scenarios, only the first execution of the prefill process will trigger this. if self._exec_add_flags: self.add_flags_custom(is_first_iteration=True) self.detailed_latency.start_predict_timer() # pylint: disable=E1102 res = self( **model_inputs, ) self.phase = "increment" # first iter done, go to other iters, in dynamic shape scenarios, only the first execution # of the increment process will trigger this. 
if self._exec_add_flags: self.add_flags_custom(is_first_iteration=False) if self.config.is_dynamic and not self.is_pynative: self._exec_add_flags = False else: # slice model inputs for incremental infer if self._pre_set_phase: self.phase = f"increment_{self._pre_set_phase}" if not (hasattr(self.config, 'parallel_decoding_params') and self.config.parallel_decoding_params): self.slice_incremental_inputs(model_inputs, current_index) self.detailed_latency.start_predict_timer() # pylint: disable=E1102 res = self( **model_inputs, ) return res def _incremental_infer_mcore(self, model_inputs: dict, prefill, gather_decode=True): r""" mcore model forward for incremental infer. Args: model_inputs: infer model inputs. prefill: flag to distinguish prefill and decode. gather_decode: whether to gather decode logits. Returns: res: the output logits. """ # Claim the first graph if prefill: self.phase = "prefill" if self._pre_set_phase: self.phase = f"prefill_{self._pre_set_phase}" # In dynamic shape scenarios, only the first execution of the prefill process will trigger this. if self._exec_add_flags: self.add_flags_custom_mcore(is_prefill=True) self.detailed_latency.start_predict_timer() # pylint: disable=E1102 res = self( **model_inputs, ) self.phase = "increment" # first iter done, go to other iters, in dynamic shape scenarios, only the first execution # of the increment process will trigger this. if self._exec_add_flags: self.add_flags_custom_mcore(is_prefill=False) self._exec_add_flags = False else: # slice model inputs for incremental infer if self._pre_set_phase: self.phase = f"increment_{self._pre_set_phase}" self.detailed_latency.start_predict_timer() # pylint: disable=E1102 res = self( **model_inputs, ) q_seq_lens = model_inputs.get("q_seq_lens", None) if gather_decode and q_seq_lens is not None: if q_seq_lens.max() > 1 and q_seq_lens.sum() == res.shape[0]: res = self.gather(res, mint.cumsum(q_seq_lens, dim=0) - 1, 0) return res def _beam_search(self, origin_inputs, beam_scorer: BeamSearchScorer, generation_config: GenerationConfig, logits_processor: Optional[LogitsProcessorList] = None, streamer: BaseStreamer = None, **model_kwargs): r""" Generates sequences of token ids for models with a language modeling head using **beam search decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: origin_inputs (`List(str), List(List(str))`): The sequence used as a prompt for the generation. beam_scorer (`BeamScorer`): An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. generation_config (`GenerationConfig`, *optional*): The generation configuration to be used as base parametrization for the generation call. `**kwargs` passed to generate matching the attributes of `generation_config` will override them. If `generation_config` is not provided, the default config from the model configuration will be used. Please note that unspecified parameters will inherit [`GenerationConfig`]'s default values, whose documentation should be checked to parameterize generation. logits_processor (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. streamer (`TextStreamer, *optional*`): The streamer that generator uses. 
model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: A list of the generated token ids """ if streamer is not None: raise ValueError("Streamer does not support in beam search method yet!") if generation_config.use_past: raise ValueError("Beam search does not support incremental inference yet! Please set use_past to False.") if self.config.is_sample_acceleration: raise ValueError("Beam search does not support sample acceleration yet! " "Please set is_sample_acceleration to False.") total_time = time.time() prepare_time = time.time() logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() batch_size = len(beam_scorer._beam_hyps) # pylint: disable=W0212 num_beams = beam_scorer.num_beams batch_beam_size = origin_inputs.shape[0] logger.debug("The input shape is: %s", origin_inputs.shape) if num_beams * batch_size != batch_beam_size: raise ValueError( f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." ) valid_length_each_example, _ = \ get_valid_length_each_example(origin_inputs, generation_config.pad_token_id) target_length = ( self.config.seq_length if generation_config.max_length > self.config.seq_length else generation_config.max_length ) logger.debug("max target_length is: %s", target_length) input_ids = self._pad_inputs_using_max_length( origin_inputs=origin_inputs, pad_token_id=generation_config.pad_token_id ) logger.debug( "pad the origin inputs from %s into shape: %s", origin_inputs.shape, input_ids.shape, ) beam_scores = np.zeros((batch_size, num_beams), dtype=np.float64) beam_scores[:, 1:] = -1e9 beam_scores = beam_scores.reshape((batch_size * num_beams,)) input_mask = np.zeros_like(input_ids) for i in range(valid_length_each_example.shape[0]): input_mask[i, :valid_length_each_example[i]] = 1 encoder_output = None encoder_mask = None if self.config.is_encoder_decoder: if target_length > self.config.max_decode_length: target_length = self.config.max_decode_length logger.debug("target_length is: %s", target_length) # When do encoder and decoder prediction, the encoder can be cached # to speed up the inference ( encoder_output, encoder_mask, input_ids, target_mask, ) = self._prepare_model_inputs_for_decoder(input_ids, input_mask) valid_length_each_example = np.ones((batch_beam_size, 1)).astype(np.int32) # update model kwargs once, before go into generate loop. 
self.update_model_kwargs_before_generate(input_ids, model_kwargs) need_gather_logits = True is_first_token = True origin_len = np.sum(valid_length_each_example) / num_beams prepare_time = time.time() - prepare_time logger.debug("forward prepare time: %s s", prepare_time) while True: forward_time = time.time() seq_length = input_ids.shape[1] current_index = [ valid_length_each_example[i] - 1 + i * seq_length for i in range(batch_beam_size) ] logger.debug("validate length: %s", valid_length_each_example) if self.config.is_encoder_decoder: inputs = Tensor(input_ids, mstype.int32) # pylint: disable=E1102 res = self( input_ids=None, attention_mask=encoder_mask, encoder_outputs=encoder_output, decoder_input_ids=inputs, decoder_attention_mask=Tensor(target_mask, mstype.float32), ) else: model_kwargs["current_index"] = current_index # model prepare input dict model_inputs = self.prepare_inputs_for_generation( # pylint: disable=E1111 input_ids, **model_kwargs ) # incremental generate if generation_config.use_past: logger.warning("Beam search currently not support incremental, " "auto-aggressive generate will be performed.") # auto-aggressive generate res = self(**model_inputs) # pylint: disable=E1102 forward_time = time.time() - forward_time search_time = time.time() # post process logits # convert to numpy for post process logits = res[0] if isinstance(res, tuple) else res if isinstance(logits, Tensor): logits = logits.asnumpy().astype(np.float32) logits = np.reshape(logits, (-1, logits.shape[-1])) # (batch_size * num_beams * seq_length, vocab_size) # need gather last seq logits using current_index # compare length to determine if need gather; if not, gather should be done in model construct if need_gather_logits and logits.shape[0] > len(current_index): logits = logits[current_index] # (total_batch_size, vocab_size) logits_processor.append(LogitNormalization()) # post process logits, without changing logits shape and order next_token_scores = logits_processor(input_ids, logits) # (batch_size * num_beams, vocab_size) # reshape for beam search vocab_size = next_token_scores.shape[-1] next_token_scores = np.reshape(next_token_scores, (batch_size, -1)) # (batch_size, num_beams * vocab_size) if is_first_token: next_token_scores = next_token_scores[:, :vocab_size] is_first_token = False # sample 2 next tokens for each beam, so we have at least 1 non eos token per beam next_token_scores, next_tokens = topk( next_token_scores, 2 * num_beams, axis=1, largest=True, sort=True ) next_indices = np.floor_divide(next_tokens, vocab_size) next_tokens = next_tokens % vocab_size beam_outputs = beam_scorer.process( input_ids, # (batch_size * num_beams, seq_length) next_token_scores, next_tokens, next_indices, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id ) beam_scores = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] search_time = time.time() - search_time update_time = time.time() # reorder model inputs old_input_ids = input_ids.copy() for i in range(batch_beam_size): input_ids[i] = old_input_ids[beam_idx[i], :] # add new tokens to input_ids for i in range(batch_beam_size): input_ids[i, valid_length_each_example[i]] = beam_next_tokens[i] if self.config.is_encoder_decoder: target_mask[i][valid_length_each_example[i]] = int(1) input_mask[i][valid_length_each_example[i]] = 1 valid_length_each_example[i] += int(1) update_time = time.time() - update_time logger.debug("forward time: %s s; beam 
search time: %s s; update time: %s s; total count: %s s", forward_time, search_time, update_time, forward_time + search_time + update_time) if beam_scorer.is_done or np.min(valid_length_each_example) >= generation_config.max_length: break sequence_outputs = beam_scorer.finalize( input_ids, beam_scores, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, max_length=generation_config.max_length ) generate_len = np.sum(valid_length_each_example) / num_beams - origin_len total_time = time.time() - total_time logger.info("total time: %s s; generated tokens: %s tokens; generate speed: %s tokens/s", total_time, generate_len, generate_len / total_time) return sequence_outputs["sequences"]
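
    # Beam-search entry sketch (illustrative only): `generate()` below builds the BeamSearchScorer
    # and dispatches to `_beam_search` whenever num_beams > 1. The model and tokenizer names mirror
    # the examples in the `generate` docstring and are assumptions here; `use_past` must be False
    # in the model config for beam search.
    #   >>> from mindformers import LlamaForCausalLM, LlamaTokenizer
    #   >>> model = LlamaForCausalLM.from_pretrained("llama2_7b")
    #   >>> tokenizer = LlamaTokenizer.from_pretrained("llama2_7b")
    #   >>> ids = tokenizer("I love Beijing, because", max_length=16, padding="max_length")["input_ids"]
    #   >>> out = model.generate(ids, num_beams=3)   # beam search forces do_sample=False
    #   >>> print(tokenizer.decode(out[0], skip_special_tokens=True))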
[docs] def generate(self, input_ids: Optional[Union[List[int], List[List[int]]]], generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, streamer: Optional[BaseStreamer] = None, seed: Optional[int] = None, **kwargs): r""" Generate the words according to the given the input ids. Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the model's default generation configuration. You can override any `generation_config` by passing the corresponding parameters to generate(), e.g. `.generate(inputs, top_k=3, do_sample=True)`. Args: input_ids (List(str), List(List(str))): The token id list or a batch of token id list. When input a batch of token id list, the length of each token id list should be same. generation_config (`GenerationConfig`, optional): The generation configuration to be used as base parametrization for the generation call. `**kwargs` passed to generate matching the attributes of `generation_config` will override them. If `generation_config` is not provided, the default config from the model configuration will be used. Please note that unspecified parameters will inherit [`GenerationConfig`]'s default values, whose documentation should be checked to parameterize generation. Default: ``None``. logits_processor (`LogitsProcessorList`, optional): Custom logits processors that complement the default logits processors built from arguments and generation config. If a logit processor is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. Default: ``None``. streamer (TextStreamer, optional): The streamer that generator uses. Default: ``None``. seed (int, optional): Random seed used in sample. Default: ``None``. kwargs: Specific parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. Supported `generate_config` keywords can be checked in [`GenerationConfig`]'s documentation. Mainly used Keywords are shown below: - max_length (int): The maximum length the generated tokens can have. Corresponds to the length of the input prompt + `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set. - max_new_tokens (int): The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. - min_length (int): The minimum length of the sequence to be generated. Corresponds to the length of the input prompt + `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set. - min_new_tokens (int): The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt. - do_sample (bool): Whether to do sampling on the candidate ids. If set True it will be enabled, and set it to be False to disable the sampling, equivalent to top-k 1. If set None, it follows the setting in the configureation in the model. - top_k (int): Determine the top-k numbers token id as candidate. This should be a positive number. If set None, it follows the setting in the configureation in the model. - top_p (float): The accumulation probability of the candidate token ids below the top-p will be select as the condaite ids. The valid value of top-p is between (0, 1]. If the value is larger than 1, top-k algorithm will be enabled. If set None, it follows the setting in the configureation in the model. - eos_token_id (int): The end of sentence token id. 
If set None, it follows the setting in the configureation in the model. - pad_token_id (int): The pad token id. If set None, it follows the setting in the configureation in the model. - repetition_penalty (float): The penalty factor of the frequency that generated words. The If set 1, the repetition_penalty will not be enabled. If set None, it follows the setting in the configureation in the model. Default: ``None``. - num_beams (int): Number of beams for beam search. 1 means no beam search. If larger than 1, do_sample will be set to false. Returns: A list of the generated token ids. Examples: >>> from mindformers import LlamaForCausalLM, LlamaTokenizer >>> import mindspore as ms >>> ms.set_context(mode=0) >>> llama = LlamaForCausalLM.from_pretrained("llama2_7b") >>> tokenizer = LlamaTokenizer.from_pretrained("llama2_7b") >>> words = "translate the English to the Romanian: UN Chief Says There Is No Military Solution in Syria" >>> words = tokenizer(words, max_length=21, padding='max_length')['input_ids'] >>> output = llama.generate(words, do_sample=True) >>> output = tokenizer.decode(output[0], skip_special_tokens=True) >>> print(output) UN Chief Says There Is No Military Solution in Syria The United Nations Secretary-General, Ban Ki-moon, said that there is no military solution in Syria, calling on the international community >>> # Enable the top-p sampling >>> output = llama.generate(words, do_sample=True, top_p=0.4) >>> output = tokenizer.decode(output[0], skip_special_tokens=True) >>> print(output) UN Chief Says There Is No Military Solution in Syria UN Chief Says There Is No Military Solution in Syria. >>> # Enable the top-k sampling. >>> output = llama.generate(words, do_sample=True, top_k=10, top_p=1) >>> output = tokenizer.decode(output[0], skip_special_tokens=True) >>> print(output) Translation by: Adela Popa English Text: UN chief warns Syria conflict threatens entire region >>> from mindformers import LlamaForCausalLM, LlamaTokenizer >>> llama = LlamaForCausalLM.from_pretrained("llama2_7b") >>> tokenizer = LlamaTokenizer.from_pretrained("llama2_7b") >>> words = "translate the English to the Romanian: UN Chief Says There Is No Military Solution in Syria" >>> words = tokenizer(words, max_length=21, padding='max_length')['input_ids'] >>> output = llama.generate(words, num_beams=3) >>> output = tokenizer.decode(output[0], skip_special_tokens=True) >>> print(output) UN Chief Says There Is No Military Solution in Syria UN Chief Says There Is No Military Solution in Syria. """ self.detailed_latency.clear() origin_phase = self.phase self.set_train(False) try: input_ids = np.array(input_ids) except ValueError as e: raise ValueError(str(e) + " Please check your inputs of model.generate()," " and make sure the inputs are padded to same length.") from e input_ids = np.reshape(input_ids, (-1, np.shape(input_ids)[-1])) batch_size = input_ids.shape[0] if seed is not None: if not isinstance(seed, int): raise ValueError(f"Invalid seed type: {type(seed)}. Seed must be an integer.") if not 0 <= seed < 2**64: raise ValueError(f"Invalid seed value: {seed}. 
Seed must be in the range [0, 2**64 - 1].") np.random.seed(seed) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() # use_past should be defined in model config use_past_tmp = kwargs.pop("use_past", None) if use_past_tmp is not None: logger.warning("use_past should be defined in model config, it will not take effect when passed to " ".generate() method.") use_legacy = get_context("use_legacy", True) # Handle `generation_config` and kwargs that might update it # priority: `generation_config` argument > `model.generation_config` (default config) if generation_config is None: # legacy: users may modify the model configuration to control generation # model attribute accordingly, if it was created from the model config generation_config = GenerationConfig.from_model_config(self.config) generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update( **kwargs ) # All unused kwargs must be model kwargs if generation_config.num_beams > 1: logger.warning("When num_beams is set to a value greater than 1, do_sample will be set to False, " "due to the current beam search does not support sampling.") generation_config.do_sample = False logger.info("Generation Config is: %s", generation_config) if generation_config.pad_token_id is None: generation_config.pad_token_id = 0 valid_length_each_example, input_ids_length = \ get_valid_length_each_example(input_ids, generation_config.pad_token_id) if hasattr(self.config, "extend_method") and self.config.extend_method == "DYNAMIC_NTK": if not self.config.is_dynamic: raise ValueError("Dynamic NTK predict mode only support is_dynamic=True, but get is_dynamic=False") if generation_config.max_new_tokens is not None: generation_config.max_length = generation_config.max_new_tokens + input_ids_length if generation_config.max_length > self.config.seq_length: logger.warning("max_length %s can not exceeds model seq_length %s, set max_length = seq_length.", generation_config.max_length, self.config.seq_length) generation_config.max_length = self.config.seq_length logger.debug("max length is: %s", generation_config.max_length) if not self.config.is_encoder_decoder and input_ids_length > generation_config.max_length: raise ValueError( "The max_length set is smaller than the length in the input_ids." f"You shout set max_length to {input_ids_length}" ) if generation_config.max_new_tokens is not None: max_length_each_example = [valid_length + generation_config.max_new_tokens \ for valid_length in valid_length_each_example] else: max_length_each_example = [generation_config.max_length] * len(valid_length_each_example) if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: logger.warning(f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is " f"larger than the maximum possible length ({generation_config.max_length})." f" Generation will stop at the defined maximum length. " f"You should decrease the minimum length and/or increase the maximum length.") if generation_config.min_new_tokens is not None: min_length = generation_config.min_new_tokens + input_ids_length if min_length > generation_config.max_length: logger.warning( f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when " f"added to the prompt length ({input_ids_length}), is larger than" f" the maximum possible length ({generation_config.max_length}). " f"Generation will stop at the defined maximum length. 
" f"You should decrease the minimum length and/or increase the maximum length." ) logits_processor = self.get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_length, logits_processor=logits_processor, ) # determine generation mode generation_config.generation_mode = self._get_generation_mode(generation_config) logger.info(f"The generation mode will be **{generation_config.generation_mode.upper()}**.") if streamer is not None and (generation_config.num_beams > 1): raise ValueError( "`streamer` cannot be used with beam search yet. Make sure that `num_beams` is set to 1." ) if not use_legacy: self._set_block_mgr(batch_size, self.config.seq_length) self.set_dynamic_inputs() elif generation_config.use_past: self._set_block_mgr(batch_size, self.config.seq_length) if self.config.is_dynamic: self.set_dynamic_inputs() # prepare dict outputs if generation_config.return_dict_in_generate and generation_config.output_logits \ and self.config.is_sample_acceleration: logger.warning("When `is_sample_acceleration` is True, logits can not be fetched. " "Set `output_logits` to False.") generation_config.output_logits = False scores = () if generation_config.return_dict_in_generate and generation_config.output_scores else None raw_logits = () if generation_config.return_dict_in_generate and generation_config.output_logits else None # beam search if generation_config.generation_mode == GenerationMode.BEAM_SEARCH: # prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=generation_config.num_beams, max_length=generation_config.max_length ) # interleave input_ids with `num_beams` additional sequences per batch input_ids = np.repeat(input_ids, generation_config.num_beams, 0) # run beam search output_ids = self._beam_search( origin_inputs=input_ids, beam_scorer=beam_scorer, generation_config=generation_config, logits_processor=logits_processor, streamer=streamer, **model_kwargs ) # greedy search or sample else: total_time = time.time() prepare_time = time.time() origin_inputs = input_ids logits_warper = self.get_logits_warper(generation_config) \ if generation_config.generation_mode == GenerationMode.SAMPLE else None if streamer is not None: streamer.put(origin_inputs) batch_size = origin_inputs.shape[0] logger.debug("The input shape is: %s", origin_inputs.shape) valid_length_each_example, _ = \ get_valid_length_each_example(origin_inputs, generation_config.pad_token_id) input_ids = self._pad_inputs_using_max_length( origin_inputs=origin_inputs, pad_token_id=generation_config.pad_token_id ) logger.debug( "pad the origin inputs from %s into shape: %s", origin_inputs.shape, input_ids.shape, ) input_mask = np.zeros_like(input_ids) for i in range(valid_length_each_example.shape[0]): input_mask[i, :valid_length_each_example[i]] = 1 encoder_output = None encoder_mask = None target_mask = None if self.config.is_encoder_decoder: if generation_config.max_length > self.config.max_decode_length: generation_config.max_length = self.config.max_decode_length logger.debug("max decode length is: %s", generation_config.max_length) # When do encoder and decoder prediction, the encoder can be cached # to speed up the inference ( encoder_output, encoder_mask, input_ids, target_mask, ) = self._prepare_model_inputs_for_decoder(input_ids, input_mask) valid_length_each_example = np.array([1 for _ in range(batch_size)]) # A single loop generates one token, loop until reaching target # model_origin_max_length or generating eod token is_finished = [False] * 
batch_size # update model kwargs once, before go into generate loop. self.update_model_kwargs_before_generate(input_ids, model_kwargs) origin_len = np.sum(valid_length_each_example) prepare_time = time.time() - prepare_time logger.debug("forward prepare time: %s s", prepare_time) prefill = True model_kwargs["origin_inputs"] = origin_inputs if hasattr(self.config, 'pet_config') and self.config.pet_config.pet_type == "slora": adapter_id = kwargs.pop("adapter_id", None) if adapter_id is not None and len(adapter_id) > 1: if len(adapter_id) != batch_size: raise ValueError("adapter_ids has different length with inputs.") model_kwargs["adapter_ids"] = adapter_id else: model_kwargs["adapter_ids"] = adapter_id * batch_size if adapter_id is not None else None while np.sum(is_finished) != batch_size: self.detailed_latency.start_preprocess_timer() block_tables = None slot_mapping = None if not use_legacy or generation_config.use_past: if prefill: if (use_legacy and self.is_pynative and self.config.is_dynamic): max_input_length = len(origin_inputs[0]) else: max_input_length = self.config.seq_length block_tables, slot_mapping = self.block_mgr.assemble_pa_full_inputs(max_input_length, valid_length_each_example, is_finished) else: block_tables, slot_mapping = self.block_mgr.assemble_pa_inc_inputs(valid_length_each_example, is_finished) self.profile.start_profiling(valid_length_each_example[0] - input_ids_length) if use_legacy: infer_output, is_finished = self.infer(input_ids=input_ids, valid_length_each_example=valid_length_each_example, generation_config=generation_config, logits_processor=logits_processor, logits_warper=logits_warper, block_tables=block_tables, slot_mapping=slot_mapping, prefill=prefill, is_finished=is_finished, encoder_mask=encoder_mask, encoder_output=encoder_output, target_mask=target_mask, **model_kwargs) else: infer_output, is_finished = self.infer_mcore(input_ids=input_ids, valid_length_each_example=valid_length_each_example, generation_config=generation_config, logits_processor=logits_processor, logits_warper=logits_warper, block_tables=block_tables, slot_mapping=slot_mapping, prefill=prefill, is_finished=is_finished, **model_kwargs) self.profile.stop_profiling(valid_length_each_example[0] - input_ids_length) if generation_config.return_dict_in_generate: target_list = infer_output["target_list"] if generation_config.output_scores: scores += (infer_output["probs"],) if generation_config.output_logits: raw_logits += (infer_output["logits"],) else: target_list = infer_output if not use_legacy or generation_config.use_past: if prefill and "origin_inputs" in model_kwargs: model_kwargs.pop("origin_inputs") prefill = False for i in range(batch_size): if is_finished[i]: continue input_ids[i, valid_length_each_example[i]] = target_list[i] if self.config.is_encoder_decoder: target_mask[i][valid_length_each_example[i]] = int(1) # Stop judgment if target_list[i] in generation_config.eos_token_id \ or valid_length_each_example[i] + 1 == generation_config.max_length \ or valid_length_each_example[i] + 1 == max_length_each_example[i]: is_finished[i] = True else: valid_length_each_example[i] += 1 input_mask[i][valid_length_each_example[i] - 1] = 1 if streamer is not None: if batch_size == 1: streamer.put(target_list[0]) else: streamer.put(target_list) self.detailed_latency.end_postprocess_timer() # Return valid outputs out of padded outputs valid_length_each_example += 1 output_ids = [] for i in range(batch_size): output_ids.append( input_ids[i, : 
int(valid_length_each_example[i])].astype(np.int32) ) logger.debug("The output is: %s", output_ids) if streamer is not None: streamer.end() generate_len = np.sum(valid_length_each_example) - origin_len total_time = time.time() - total_time logger.info("total time: %s s; generated tokens: %s tokens; generate speed: %s tokens/s", total_time, generate_len, generate_len / total_time) self.detailed_latency.print_info() # set to original phase self.set_train(origin_phase == "train") if self.block_mgr: self.block_mgr.clear_cache() if generation_config.return_dict_in_generate: result = GenerateOutput( sequences=output_ids, scores=scores, logits=raw_logits ) return result return output_ids
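
    # Dict-output sketch (illustrative only; reuses the hypothetical `model`/`ids` from the sketch
    # above): with return_dict_in_generate=True, `generate()` returns a GenerateOutput holding
    # `sequences`, plus `scores`/`logits` when the corresponding output_* flags are enabled
    # (field names as constructed in this module).
    #   >>> out = model.generate(ids, do_sample=False, return_dict_in_generate=True, output_logits=True)
    #   >>> out.sequences[0][:8]   # generated token ids for the first sequence
    #   >>> len(out.logits)        # one logits entry per generated step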
    def infer(self,
              input_ids: Union[List[int], List[List[int]]],
              valid_length_each_example: np.ndarray,
              generation_config: GenerationConfig = None,
              logits_processor: Optional[LogitsProcessorList] = None,
              logits_warper: Optional[LogitsProcessorList] = None,
              block_tables: Optional[Tensor] = None,
              slot_mapping: Optional[Tensor] = None,
              prefill: bool = True,
              is_finished: List[bool] = None,
              encoder_mask: Optional[Tensor] = None,
              encoder_output: Optional[Tensor] = None,
              target_mask: Optional[Tensor] = None,
              **model_kwargs):
        r"""
        Run one inference step and return the logits at the next position; either a prefill or a
        decode step can be chosen.

        Args:
            input_ids (List(List(int))): Input ids after padding.
            valid_length_each_example (np.ndarray): Valid input length, excluding padding.
            generation_config (`GenerationConfig`, optional): The generation configuration to be used
                as base parametrization for the generation call. Default: ``None``.
            logits_processor (`LogitsProcessorList`, optional): An instance of [`LogitsProcessorList`].
                List of instances of class derived from [`LogitsProcessor`] used to modify the
                prediction scores of the language modeling head applied at each generation step.
                Default: ``None``.
            logits_warper (`LogitsProcessorList`, optional): An instance of [`LogitsProcessorList`].
                List of instances of class derived from [`LogitsWarper`] used to warp the prediction
                score distribution of the language modeling head applied before multinomial sampling
                at each generation step. Default: ``None``.
            block_tables (Tensor, optional): Block mapping tables for each sequence. Default: ``None``.
            slot_mapping (Tensor, optional): Physical slot index of the token cache. Default: ``None``.
            prefill (bool, optional): Whether to run a prefill step (``True``) or a decode step
                (``False``). Default: ``True``.
            is_finished (List(bool), optional): Whether each sequence has finished its generation.
                Default: ``None``.
            encoder_mask (Tensor, optional): Used by encoder-decoder models; not needed for
                decoder-only models. Default: ``None``.
            encoder_output (Tensor, optional): Used by encoder-decoder models; not needed for
                decoder-only models. Default: ``None``.
            target_mask (Tensor, optional): Used by encoder-decoder models; not needed for
                decoder-only models. Default: ``None``.
            **model_kwargs (Any): Keyword arguments of the model.

        Returns:
            next_token, the next token to be generated.
            is_finished, whether each sequence has completed its generation task.
        """
        max_valid_length = max(valid_length_each_example)
        if not self.config.is_encoder_decoder and max_valid_length > self.config.seq_length:
            raise ValueError(
                f"The input length:{max_valid_length} is longer than the seq_length:{self.config.seq_length}, "
                "which is not allowed."
            )
        start_time = time.time()
        input_ids = np.array(input_ids)
        res, current_index = self.forward(input_ids=input_ids,
                                          valid_length_each_example=valid_length_each_example,
                                          block_tables=block_tables,
                                          slot_mapping=slot_mapping,
                                          prefill=prefill,
                                          use_past=generation_config.use_past,
                                          encoder_mask=encoder_mask,
                                          encoder_output=encoder_output,
                                          target_mask=target_mask,
                                          **model_kwargs)
        self.detailed_latency.start_postprocess_timer()
        forward_time = time.time() - start_time
        sample_time = time.time()
        need_gather_logits = True
        if not self.config.is_encoder_decoder and generation_config.use_past:
            need_gather_logits = prefill
        target_list, probs, logits, is_finished = self.postprocess(
            input_ids=input_ids,
            is_finished=is_finished,
            res=res,
            generation_config=generation_config,
            valid_length_each_example=valid_length_each_example,
            current_index=current_index,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            need_gather_logits=need_gather_logits
        )
        sample_time = time.time() - sample_time
        infer_time = time.time() - start_time
        logger.debug("forward time: %s s; sample time: %s s; total count: %s s",
                     forward_time, sample_time, infer_time)
        if generation_config.return_dict_in_generate:
            infer_output_dict = InferOutput(
                target_list=target_list,
                probs=probs,
                logits=logits
            )
            return infer_output_dict, is_finished
        return target_list, is_finished
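
    # Single-step sketch (illustrative only; `generate()` normally drives this loop). The names
    # `model`, `ids`, `gen_cfg`, `block_tables` and `slot_mapping` are assumptions standing in for a
    # decoder-only model with use_past=True, a padded int32 numpy batch, and page-attention inputs
    # assembled by the BlockTables manager:
    #   >>> lengths, _ = get_valid_length_each_example(ids, gen_cfg.pad_token_id)
    #   >>> tokens, finished = model.infer(input_ids=ids,
    #   ...                                valid_length_each_example=lengths,
    #   ...                                generation_config=gen_cfg,
    #   ...                                block_tables=block_tables,
    #   ...                                slot_mapping=slot_mapping,
    #   ...                                prefill=True,
    #   ...                                is_finished=[False] * ids.shape[0])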
[docs] def forward(self, input_ids: [Union[List[int], List[List[int]]]], valid_length_each_example: np.ndarray, block_tables: Optional[Tensor] = None, slot_mapping: Optional[Tensor] = None, prefill: bool = None, use_past: bool = False, encoder_mask: Optional[Tensor] = None, encoder_output: Optional[Tensor] = None, target_mask: Optional[Tensor] = None, key_cache: Optional[List[Tensor]] = None, value_cache: Optional[List[Tensor]] = None, **model_kwargs): r""" Model forward process. Args: input_ids (List[List[int]]): Input ids after padding. valid_length_each_example (np.ndarray): Valid input length except padding. block_tables (Tensor, optional): Params for page attention. Default: ``None``. slot_mapping (Tensor, optional): Params for page attention. Default: ``None``. prefill (bool, optional): Whether to do prefill predict or decode predict. Default: ``None``. use_past (bool, optional): Whether to use past. Default: ``False``. encoder_mask (Tensor, optional): Use for encoder-decoder construct, do not need for decoder only construct. Default: ``None``. encoder_output (Tensor, optional): Use for encoder-decoder construct, do not need for decoder only construct. Default: ``None``. target_mask (Tensor, optional): Use for encoder-decoder construct, do not need for decoder only construct. Default: ``None``. key_cache (List[Tensor], optional): A group of tensors used for kvcache. Default: ``None``. value_cache (List[Tensor], optional): A group of tensors used for kvcache. Default: ``None``. **model_kwargs (Any): Keyword arguments of the model. Returns: res, the result after the forward process. current_index, records the current index of the sequence. """ input_ids = np.reshape(input_ids, (-1, np.shape(input_ids)[-1])) if self.config.is_encoder_decoder: inputs = Tensor(input_ids, mstype.int32) # pylint: disable=E1102 res = self( input_ids=None, attention_mask=encoder_mask, encoder_outputs=encoder_output, decoder_input_ids=inputs, decoder_attention_mask=Tensor(target_mask, mstype.float32), ) else: if parallel_decoding_control(self.config): current_index = None else: current_index = valid_length_each_example - 1 + np.arange(input_ids.size, step=input_ids.shape[1]) model_kwargs["current_index"] = current_index model_kwargs["prefill"] = prefill if use_past else None model_kwargs["valid_length_each_example"] = valid_length_each_example model_kwargs["block_tables"] = block_tables model_kwargs["slot_mapping"] = slot_mapping # pylint: disable=E1111 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) real_input_ids = model_inputs["input_ids"] if parallel_decoding_control(self.config): model_inputs, block_tables, slot_mapping = parallel_decoding_process( self.config, input_ids, model_inputs, **model_kwargs ) else: current_index = valid_length_each_example - 1 + np.arange(real_input_ids.numel(), step=real_input_ids.shape[1]) if use_past: if "batch_valid_length" not in model_inputs: model_inputs["batch_valid_length"] = Tensor.from_numpy( np.array([valid_length_each_example], dtype=np.int32)) if block_tables is not None and "block_tables" not in model_inputs: model_inputs["block_tables"] = Tensor.from_numpy(block_tables) if slot_mapping is not None and "slot_mapping" not in model_inputs: model_inputs["slot_mapping"] = Tensor.from_numpy(slot_mapping) res = self._incremental_infer( model_inputs=model_inputs, prefill=prefill, current_index=current_index, key_cache=key_cache, value_cache=value_cache ) else: if self._pre_set_phase: self.phase = f"predict_{self._pre_set_phase}" res = 
self(**model_inputs) # pylint: disable=E1102 return res, current_index
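
    # Index sketch (illustrative only): without parallel decoding, `current_index` flattens the
    # batch so each entry points at the last valid position of its row in the
    # (batch_size * seq_length) logits, e.g. for seq_length 5 and valid lengths [3, 2]:
    #   >>> valid = np.array([3, 2])
    #   >>> valid - 1 + np.arange(2 * 5, step=5)
    #   array([2, 6])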
def prepare_inputs_for_generation_mcore(self, input_ids: [Union[List[int], List[List[int]]]], valid_length_each_example: np.ndarray, block_tables: Optional[Tensor] = None, slot_mapping: Optional[Tensor] = None, prefill: bool = None, **model_kwargs): """prepare inputs for mcore""" model_inputs = dict() seq_lens = np.array(valid_length_each_example) q_seq_lens = model_kwargs.get("q_seq_lens", None) positions = model_kwargs.get("position_ids", None) attention_mask = model_kwargs.get("attention_mask", None) if q_seq_lens is None or np.size(q_seq_lens) == 0: if len(input_ids) == len(seq_lens): q_seq_lens = np.ones_like(seq_lens) else: q_seq_lens = np.array(valid_length_each_example) q_seq_lens = np.array(q_seq_lens) if prefill and len(input_ids) != q_seq_lens.sum(): q_seq_lens = np.array(valid_length_each_example) context_lens = seq_lens - q_seq_lens if positions is None: positions = np.zeros_like(input_ids, dtype=np.int32) start = 0 for i in range(seq_lens.size): positions[start:start + q_seq_lens[i]] = np.arange(context_lens[i], seq_lens[i]) start += q_seq_lens[i] if prefill and context_lens.max() > 0: prefill = False model_inputs["input_ids"] = Tensor.from_numpy(input_ids.astype(np.int32)) model_inputs["batch_valid_length"] = Tensor.from_numpy(seq_lens.astype(np.int32)) model_inputs["context_lens_tensor"] = Tensor.from_numpy(context_lens.astype(np.int32)) model_inputs["q_seq_lens"] = Tensor.from_numpy(q_seq_lens.astype(np.int32)) model_inputs["positions"] = Tensor.from_numpy(positions.astype(np.int32)) model_inputs["block_tables"] = Tensor.from_numpy(block_tables) model_inputs["slot_mapping"] = Tensor.from_numpy(slot_mapping) if attention_mask is not None: if isinstance(attention_mask, np.ndarray): attention_mask = Tensor.from_numpy(attention_mask) model_inputs["attention_mask"] = attention_mask.astype(self.config.compute_dtype) else: model_inputs["attention_mask"] = None model_inputs["attn_metadata"] = None model_inputs["kv_cache"] = None return model_inputs, prefill def forward_mcore(self, input_ids: [Union[List[int], List[List[int]]]], valid_length_each_example: np.ndarray, block_tables: Optional[Tensor] = None, slot_mapping: Optional[Tensor] = None, prefill: bool = None, **model_kwargs): r""" Model forward process. Args: input_ids (List(List(int))): Input ids after padding. valid_length_each_example (np.ndarray): Valid input length except padding. block_tables (Tensor, optional): Params for page attention. Default: ``None``. slot_mapping (Tensor, optional): Params for page attention. Default: ``None``. prefill (bool, optional): Whether to do prefill predict or decode predict. Default: ``None``. **model_kwargs (Any): Keyword arguments of the model. Returns: res, the result after the forward process. current_index, records the current index of the sequence. 
""" attention_mask = None gather_decode = True if isinstance(self.config.parallel_decoding_params, Dict): plugin_type = self.config.parallel_decoding_params.get("plugin_type") else: plugin_type = None if plugin_type == "la": slot_mapping, attention_mask = la_pre_process(input_ids, slot_mapping, **model_kwargs) model_kwargs["attention_mask"] = attention_mask # lookahead should not gather decode logits gather_decode = False model_inputs, prefill = self.prepare_inputs_for_generation_mcore( input_ids=input_ids, valid_length_each_example=valid_length_each_example, block_tables=block_tables, slot_mapping=slot_mapping, prefill=prefill, **model_kwargs, ) res = self._incremental_infer_mcore( model_inputs=model_inputs, prefill=prefill, gather_decode=gather_decode ) return res, None def infer_mcore(self, input_ids: Union[List[int], List[List[int]]], valid_length_each_example: np.ndarray, generation_config: GenerationConfig = None, logits_processor: Optional[LogitsProcessorList] = None, logits_warper: Optional[LogitsProcessorList] = None, block_tables: Optional[Tensor] = None, slot_mapping: Optional[Tensor] = None, prefill: bool = True, is_finished: List[bool] = None, **model_kwargs): r""" Do infer and return logits on next position, can choose do prefill or decode predict. Args: input_ids (List(List(int))): Input ids after padding. valid_length_each_example (np.ndarray): Valid input length except padding. generation_config (`GenerationConfig`, optional): The generation configuration to be used as base parametrization for the generation call. Default: ``None``. logits_processor (`LogitsProcessorList`, optional): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. Default: ``None``. logits_warper (`LogitsProcessorList`, optional): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used to warp the prediction score distribution of the language modeling head applied before multinomial sampling at each generation step. Default: ``None``. block_tables (Tensor, optional): Store mapping tables for each sequence. Default: ``None``. slot_mapping (Tensor, optional): Token cache physical slot index. Default: ``None``. prefill (bool, optional): Whether to do prefill predict or decode predict. Default: ``True``. is_finished (List(bool), optional): Whether each sequence is finished its generation. Default: ``None``. **model_kwargs (Any): Keyword arguments of the model. Returns: next_token, the next token to be generated. is_finished, whether the sequence has completed its generation task. """ max_valid_length = max(valid_length_each_example) if max_valid_length > self.config.seq_length: raise ValueError( f"The input length:{max_valid_length} is longer than the seq_length:{self.config.seq_length}, " "which is not allowed." 
) start_time = time.time() flatten_input_ids, slot_mapping = self._prepare_inputs_for_flatten( input_ids, valid_length_each_example, slot_mapping, prefill ) res, current_index = self.forward_mcore( input_ids=flatten_input_ids, valid_length_each_example=valid_length_each_example, block_tables=block_tables, slot_mapping=slot_mapping, prefill=prefill, **model_kwargs, ) self.detailed_latency.start_postprocess_timer() forward_time = time.time() - start_time sample_time = time.time() target_list, probs, logits, is_finished = self.postprocess( input_ids=input_ids, is_finished=is_finished, res=res, current_index=current_index, generation_config=generation_config, valid_length_each_example=valid_length_each_example, logits_processor=logits_processor, logits_warper=logits_warper, need_gather_logits=False, ) sample_time = time.time() - sample_time infer_time = time.time() - start_time logger.debug("forward time: %s s; sample time: %s s; total count: %s s", forward_time, sample_time, infer_time) if generation_config.return_dict_in_generate: infer_output_dict = InferOutput( target_list=target_list, probs=probs, logits=logits ) return infer_output_dict, is_finished return target_list, is_finished def _prepare_inputs_for_flatten(self, input_ids, valid_length_each_example, slot_mapping, prefill=True): """prepare inputs ids for prefill flatten""" input_ids = np.array(input_ids) batch_valid_length_bs = valid_length_each_example.shape[0] if prefill: input_ids_list = [] for i in range(batch_valid_length_bs): input_ids_list.append(input_ids[i][:valid_length_each_example[i]]) input_ids = np.concatenate(input_ids_list, 0) slot_mapping = np.delete(slot_mapping, np.where(slot_mapping == -1)) else: batch_valid_length_bs = valid_length_each_example.shape[0] input_ids_list = [] for i in range(batch_valid_length_bs): input_ids_list.append(input_ids[i][valid_length_each_example[i] - 1]) input_ids = np.array(input_ids_list) input_ids = input_ids.reshape((-1)) return input_ids, slot_mapping # pylint: disable=E1102 def chunk_prefill_infer(self, input_ids: [Union[List[int], List[List[int]]]], batch_valid_length: np.ndarray, block_tables: np.ndarray, slot_mapping: np.ndarray, attention_mask: Optional[np.ndarray] = None, **model_kwargs): """ Preprocessing of chunk prefill inference Args: input_ids (List(List(int))): Input ids. batch_valid_length (np.ndarray): Valid input length. block_tables (np.ndarray): Params for page attention. slot_mapping (np.ndarray): Params for page attention. attention_mask (np.ndarray): Params for page attention. q_seq_lens (np.ndarray): Params for page attention. gather_index (np.ndarray): Used to obtain the last latent vector of each sequence. seq_range (np.ndarray): Used to obtain Mask and positional encoding of valid tokens for each sequence. 
""" if not (self.use_past and self.chunk_prefill): raise ValueError(f"chunk prefill infer can be called only when use_past=true and chunk_prefill=true, \ but use_past={self.use_past}, chunk_prefill={self.chunk_prefill}") # decode if "gather_index" not in model_kwargs or "seq_range"not in model_kwargs \ or "q_seq_lens" not in model_kwargs: model_kwargs["gather_index"] = None model_kwargs["seq_range"] = None model_kwargs["q_seq_lens"] = None self.add_flags_custom(is_first_iteration=False) else: # decode + chunk input_ids = np.reshape(input_ids, (1, -1)) model_kwargs["gather_index"] = Tensor(model_kwargs["gather_index"], ms.int32) model_kwargs["seq_range"] = Tensor(model_kwargs["seq_range"], ms.int32) model_kwargs["q_seq_lens"] = Tensor(model_kwargs["q_seq_lens"], ms.int32) self.add_flags_custom(is_first_iteration=True) if attention_mask is not None: model_kwargs["attention_mask"] = Tensor(attention_mask, ms.float16) model_kwargs["input_ids"] = Tensor(input_ids, ms.int32) model_kwargs["batch_valid_length"] = Tensor(batch_valid_length, ms.int32) model_kwargs["block_tables"] = Tensor(block_tables, ms.int32) model_kwargs["slot_mapping"] = Tensor(slot_mapping, ms.int32) logits = self(**model_kwargs) return logits
[docs]    def postprocess(self,
                    input_ids,
                    is_finished,
                    res,
                    generation_config: GenerationConfig,
                    valid_length_each_example,
                    current_index: Optional[Union[List[int], List[List[int]]]],
                    logits_processor: Optional[LogitsProcessorList] = None,
                    logits_warper: Optional[LogitsProcessorList] = None,
                    need_gather_logits: bool = True):
        r"""
        Postprocess of the output from model generation.

        Args:
            input_ids (List(List(int))): Input ids after padding.
            is_finished (List(bool)): Whether each sequence has finished its generation.
            res (List(List(int))): Logits after infer.
            generation_config (`GenerationConfig`):
                The generation configuration to be used as base parametrization for the generation call.
            valid_length_each_example (np.ndarray): Valid input length except padding.
            current_index (List(int)): Current index of sequence.
            logits_processor (`LogitsProcessorList`, optional):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from
                [`LogitsProcessor`] used to modify the prediction scores of the language modeling head
                applied at each generation step. Default: ``None``.
            logits_warper (`LogitsProcessorList`, optional):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from
                [`LogitsWarper`] used to warp the prediction score distribution of the language modeling
                head applied before multinomial sampling at each generation step. Default: ``None``.
            need_gather_logits (bool, optional): Whether to gather the logits at ``current_index``;
                needed when the model returns logits for the whole sequence, e.g. the first iteration
                of decode prediction. Default: ``True``.

        Returns:
            target_list, contains the target values generated in each batch.
            next_probs_cache, cache for probs, if needed in output.
            next_logits_cache, cache for logits, if needed in output.
            is_finished, whether the sequence has completed its generation task.
        """
        if self.use_mint_op and not self.is_pynative:
            from mindspore.common.api import _pynative_executor
            _pynative_executor.set_async_for_graph(True)
        batch_size = input_ids.shape[0]
        target_list = [[] for _ in range(batch_size)]
        # cache for logits and probs, if needed in output
        next_logits_cache = None
        next_probs_cache = None
        generation_config.generation_mode = self._get_generation_mode(generation_config)
        if generation_config.generation_mode == GenerationMode.GREEDY_SEARCH:
            if not self.config.is_sample_acceleration:
                logits = res[0] if isinstance(res, tuple) else res
                logits = logits.reshape(-1, logits.shape[-1])
                if need_gather_logits and logits.shape[0] > len(current_index):
                    logits = logits[Tensor(current_index, dtype=mstype.int32)]
                # store cached logits
                if generation_config.return_dict_in_generate and generation_config.output_logits:
                    if isinstance(logits, Tensor):
                        next_logits_cache = logits.asnumpy().copy()
                    else:
                        next_logits_cache = logits.copy()
                if logits_processor:
                    if isinstance(logits, Tensor):
                        logits = logits.asnumpy()
                    logits = Tensor(logits_processor(input_ids, logits, is_finished))
                # store cached probs
                if generation_config.return_dict_in_generate and generation_config.output_scores:
                    if isinstance(logits, Tensor):
                        next_probs_cache = logits.asnumpy().copy()
                    else:
                        next_probs_cache = logits.copy()
                target_list = self.argmax(logits, -1)
                target_list = target_list.asnumpy().tolist()
            else:
                probs, p_args = res
                if isinstance(p_args, Tensor):
                    p_args = p_args.asnumpy()
                # store cached probs
                if generation_config.return_dict_in_generate and generation_config.output_scores:
                    if isinstance(probs, Tensor):
                        next_probs_cache = probs.asnumpy().copy()
                    else:
                        next_probs_cache = probs.copy()
                target_index_list = P.Argmax()(probs)
                target_index_list = target_index_list.asnumpy().tolist()
                # run greedy search
                for i in range(batch_size):
                    if is_finished[i]:
                        continue
                    target_index = target_index_list[i]
                    target = p_args[i][target_index]
                    target_list[i] = target
        elif generation_config.generation_mode == GenerationMode.SAMPLE:
            if not self.config.is_sample_acceleration:
                # convert to numpy for post process
                logits = res[0] if isinstance(res, tuple) else res
                if isinstance(logits, Tensor):
                    logits = logits.asnumpy()
                logits = np.reshape(logits, (-1, logits.shape[-1]))
                # need gather last seq logits using current_index
                # compare length to determine if need gather; if not, gather should be done in model construct
                if need_gather_logits and logits.shape[0] > len(current_index):
                    logits = logits[current_index]
                # store cached logits
                if generation_config.return_dict_in_generate and generation_config.output_logits:
                    next_logits_cache = logits.copy()
                probs = logits_processor(input_ids, logits, is_finished)
                p_args = np.tile(np.arange(logits.shape[-1]), (batch_size, 1))
                probs = logits_warper(input_ids, probs, is_finished)
            else:
                probs, p_args = res
                if isinstance(probs, Tensor):
                    probs = probs.asnumpy()
                if isinstance(p_args, Tensor):
                    p_args = p_args.asnumpy()
                # store cached probs
                if generation_config.return_dict_in_generate and generation_config.output_scores:
                    next_probs_cache = probs.copy()
            p_norms = softmax_with_threads(probs, is_finished)
            for i in range(batch_size):
                if is_finished[i]:
                    continue
                p_norm = p_norms[i]
                target_index = np.random.choice(len(probs[i]), p=p_norm)
                # get target token id
                target = p_args[i][target_index]
                target_list[i] = target
        elif generation_config.generation_mode == GenerationMode.BEAM_SEARCH:
            raise ValueError("sampler method doesn't support BEAM_SEARCH.")
        if self.use_mint_op and not self.is_pynative:
            from mindspore.common.api import _pynative_executor
            _pynative_executor.sync()
            _pynative_executor.set_async_for_graph(False)
        return target_list, next_probs_cache, next_logits_cache, is_finished
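    # A toy sketch of the sampling branch in `postprocess` above, with a plain numpy softmax standing
    # in for `softmax_with_threads`; the names `toy_probs` / `toy_p_args` and their values are hypothetical.
    #
    #     >>> import numpy as np
    #     >>> toy_probs = np.array([[2.0, 1.0, 0.5]])          # warped scores for one unfinished sequence
    #     >>> toy_p_args = np.array([[7, 3, 11]])              # candidate token ids for that sequence
    #     >>> p_norm = np.exp(toy_probs[0]) / np.exp(toy_probs[0]).sum()
    #     >>> target_index = np.random.choice(len(toy_probs[0]), p=p_norm)
    #     >>> target = toy_p_args[0][target_index]             # sampled next token id
    #     >>> # greedy search instead takes the argmax: toy_p_args[0][np.argmax(toy_probs[0])] -> 7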
[docs]    def chat(self,
             tokenizer: PreTrainedTokenizer,
             query: str,
             history: Optional[List[Dict[str, str]]] = None,
             system_role_name: Optional[str] = "system",
             user_role_name: Optional[str] = "user",
             assistant_role_name: Optional[str] = "assistant",
             instruction: Optional[str] = "",
             max_length: Optional[int] = 512,
             max_new_tokens: Optional[int] = None,
             min_length: Optional[int] = 0,
             min_new_tokens: Optional[int] = None,
             do_sample: Optional[bool] = True,
             temperature: Optional[float] = 1.0,
             top_k: Optional[int] = 50,
             top_p: Optional[float] = 1.0,
             repetition_penalty: Optional[float] = 1.0):
        r"""
        Dialogue text generation inference with large language models. The user's query is run through
        generate() after the chat template has been applied via the provided tokenizer.

        Args:
            tokenizer (PreTrainedTokenizer): The tokenizer used to encode the query and decode the
                generated tokens.
            query (str): User input for inference.
            history (List[Dict[str, str]], optional): A Conversation object or list of dicts with
                "role" and "content" keys, representing the chat history so far. Default: ``None``.
            system_role_name (str, optional): The name of system role. Default: ``"system"``.
            user_role_name (str, optional): The name of user role. Default: ``"user"``.
            assistant_role_name (str, optional): The name of assistant role. Default: ``"assistant"``.
            instruction (str, optional): Instruction message to the model. Default: ``""``.
            max_length (int, optional): The maximum length the generated tokens can have. Corresponds
                to the length of the input prompt + `max_new_tokens`. Its effect is overridden by
                `max_new_tokens`, if also set. Default: ``512``.
            max_new_tokens (int, optional): The maximum number of tokens to generate, ignoring the
                number of tokens in the prompt. Default: ``None``.
            min_length (int, optional): The minimum length of the sequence to be generated. Corresponds
                to the length of the input prompt + `min_new_tokens`. Its effect is overridden by
                `min_new_tokens`, if also set. Default: ``0``.
            min_new_tokens (int, optional): The minimum number of tokens to generate, ignoring the
                number of tokens in the prompt. Default: ``None``.
            do_sample (bool, optional): Whether to do sampling on the candidate ids. If set True,
                sampling is enabled; if set False, sampling is disabled, which is equivalent to top-k 1.
                If set None, it follows the setting in the model configuration. Default: ``True``.
            temperature (float, optional): The value used to modulate the next token probabilities.
                Default: ``1.0``.
            top_k (int, optional): The number of highest-probability token ids kept as candidates.
                This should be a positive number. If set None, it follows the setting in the model
                configuration. Default: ``50``.
            top_p (float, optional): Candidate token ids whose accumulated probability stays below
                top-p are kept as candidates. The valid value of top-p is between (0, 1]. If the value
                is larger than 1, the top-k algorithm is enabled instead. If set None, it follows the
                setting in the model configuration. Default: ``1.0``.
            repetition_penalty (float, optional): The penalty factor applied to repeated words. If set 1,
                repetition_penalty is not enabled. If set None, it follows the setting in the model
                configuration. Default: ``1.0``.

        Returns:
            response, the reply from the LLM in this session.
            history, the conversation history.
        Examples:
            >>> import mindspore as ms
            >>> from mindformers.generation import text_generator
            >>> from mindformers import AutoModel, AutoTokenizer
            >>> ms.set_context(mode=0)
            >>> model = AutoModel.from_pretrained("llama2_7b")
            >>> tokenizer = AutoTokenizer.from_pretrained("llama2_7b")
            >>> query = "Hello!"
            >>> response, history = model.chat(tokenizer=tokenizer, query=query, max_length=32)
            >>> print(response)
            Thanks, sir.
        """
        if history is None:
            history = []
        if instruction:
            history.append({"role": system_role_name, "content": instruction})
        history.append({"role": user_role_name, "content": query})
        input_ids = tokenizer.apply_chat_template(conversation=history, add_generation_prompt=True)
        output_ids = self.generate(input_ids=input_ids,
                                   max_length=max_length,
                                   max_new_tokens=max_new_tokens,
                                   min_length=min_length,
                                   min_new_tokens=min_new_tokens,
                                   do_sample=do_sample,
                                   temperature=temperature,
                                   top_k=top_k,
                                   top_p=top_p,
                                   repetition_penalty=repetition_penalty)
        output_ids = output_ids[0][len(input_ids):]
        response = tokenizer.decode(output_ids, skip_special_tokens=True)
        history.append({"role": assistant_role_name, "content": response})
        return response, history
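    # A follow-up usage sketch for `chat`: the returned `history` can be passed back to keep the
    # conversation going. The `model` and `tokenizer` objects are as in the Examples docstring above;
    # the queries and reply contents are illustrative.
    #
    #     >>> response, history = model.chat(tokenizer=tokenizer, query="Hello!", max_length=32)
    #     >>> response, history = model.chat(tokenizer=tokenizer, query="Who are you?",
    #     ...                                history=history, max_length=64)
    #     >>> # with no `instruction`, `history` now holds four dicts:
    #     >>> # two "user" turns and two "assistant" turns, in order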