Building an LLM Inference Network from Scratch
Model Development Modes
MindSpore provides two model running modes:
Static graph mode: The model network is compiled into a complete computation graph, which is then optimized as a whole, improving execution performance. However, because only a subset of Python syntax is supported, model development is subject to certain limitations that affect usability.
Dynamic graph mode: The Python statements of the network script are executed one by one, so you can print and debug (for example, with pdb) at any time. This mode is easy to use, but its performance is lower than that of the static graph mode.
In MindSpore, you are advised to develop a model in dynamic graph mode and then convert the dynamic graph to a static graph as required to obtain the best performance.
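For reference, the global execution mode can be selected with mindspore.set_context before the network is built; the per-cell @jit decorator shown later in this document is another way to obtain static graph execution. A minimal sketch:
import mindspore as ms

# Develop and debug in dynamic graph (PyNative) mode, where Python statements run one by one.
ms.set_context(mode=ms.PYNATIVE_MODE)

# Switch to static graph mode when performance matters; the whole network is compiled first.
# ms.set_context(mode=ms.GRAPH_MODE)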
Backbone Network Used for Development in Dynamic Graph Mode
Most mainstream LLMs use a Transformer-based backbone network whose core computation relies on the self-attention mechanism. The following figure uses the Qwen2 LLM as an example to show the backbone network architecture.
The core layer of Qwen2 consists of the following parts:
Embedding: converts the index of each token into a feature vector. Unlike plain one-hot encoding, the embedding weights take part in training, so they adapt better to the contextual semantics of the LLM. This step is implemented with an embedding (gather) operator.
DecoderLayer: the Transformer layer, which is the key compute module of the LLM. Multiple layers are usually stacked as needed, and each layer is a complete Transformer block.
RmsNorm & Linear: after the Transformer layers, normalizes the output and linearly projects it to the vocabulary dimension, returning a probability distribution over the tokens.
You can build an inference network by assembling the operators provided by MindSpore as required. The following uses the Qwen2 model as an example to describe how to build a model. For the complete end-to-end example, see qwen2.py.
Basic Common Network Layer
The Qwen2 LLM has many configuration items and parameters. To manage them conveniently, first define the Config and Input classes used by the model. In addition, the Linear and RmsNorm operators are used frequently in every functional layer of the network, so these common layers are built in advance.
Config & Input
import json
from dataclasses import dataclass
from typing import Optional, Type, List, Tuple, Union
from mindspore import Tensor, dtype
@dataclass
class Qwen2Config:
"""Qwen2 Config, the key-value is almost the same with config.json in Hugging Face"""
architectures: Optional[List[str]] = None
attention_dropout: float = 0.0
bos_token_id: int = 151643
eos_token_id: int = 151645
hidden_act: str = "silu"
hidden_size: int = 3584
initializer_range: float = 0.02
intermediate_size: int = 18944
max_position_embeddings: int = 32768
max_window_layers: int = 28
model_type: str = "qwen2"
num_attention_heads: int = 28
num_hidden_layers: int = 28
num_key_value_heads: int = 4
rms_norm_eps: float = 1e-06
rope_theta: float = 1000000.0
sliding_window: Optional[int] = 131072
tie_word_embeddings: bool = False
torch_dtype: str = "bfloat16"
transformers_version: str = "4.41.2"
use_cache: bool = True
use_sliding_window: bool = False
vocab_size: int = 152064
    param_dtype: Optional[Type] = dtype.bfloat16  # MindSpore data type; Hugging Face stores torch_dtype as a string
@classmethod
def from_json(cls, json_path: str) -> 'Qwen2Config':
with open(json_path) as f:
data = json.load(f)
config = cls(**data)
return config
@dataclass
class Qwen2ModelInput:
input_ids: Tensor
positions: Tensor
batch_valid_length: Tensor
is_prefill: bool
attn_mask: Tensor
k_caches: List[Tensor]
v_caches: List[Tensor]
    slot_mapping: Optional[Tensor] = None
    block_tables: Optional[Tensor] = None
hidden_state: Optional[Tensor] = None
residual: Optional[Tensor] = None
q_seq_lens: Optional[Tensor] = None
The Qwen2Config configuration is basically the same as the config.json used by Hugging Face. For details, see the official Qwen2 documentation. Note that param_dtype replaces torch_dtype in Qwen2Config because the MindSpore data types differ from those of PyTorch. Qwen2ModelInput defines the model inputs, including the token IDs, the KVCache, and the inputs required by the fused attention operators used by MindSpore inference optimization features.
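As a quick illustration (the local checkpoint path below is a placeholder), the configuration can be loaded directly from the config.json shipped with the Hugging Face checkpoint:
# Hypothetical path to a downloaded Qwen2 checkpoint directory.
config = Qwen2Config.from_json("/path/to/Qwen2-7B-Instruct/config.json")
print(config.hidden_size, config.num_hidden_layers, config.param_dtype)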
RmsNorm
RmsNorm is a normalization algorithm commonly used in LLMs. MindSpore provides an operator that can be used directly; you only need to create the corresponding weight. In addition, RmsNorm is often computed together with a residual addition, so the RmsNorm class below fuses the residual computation into the network layer. The following is a code example:
from typing import Optional, Type, Union, Tuple
from mindspore import nn, ops, mint, Parameter, Tensor
class RmsNorm(nn.Cell):
def __init__(self, config: Qwen2Config) -> None:
super().__init__()
self.rms_norm = ops.RmsNorm(config.rms_norm_eps)
self.weight = Parameter(
mint.ones(
config.hidden_size,
dtype=config.param_dtype
),
requires_grad=False
)
def construct(self, x: Tensor, residual: Optional[Tensor] = None) -> Union[Tensor, Tuple[Tensor, Tensor]]:
if residual is not None:
x = x + residual
residual = x
output = self.rms_norm(x, self.weight)[0]
if residual is None:
return output
return output, residual
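For readers unfamiliar with the algorithm, the following NumPy sketch (for reference only, not used by the network) shows the computation performed by the fused ops.RmsNorm operator: each vector is divided by its root mean square over the last dimension and then scaled by the learned weight.
import numpy as np

def rms_norm_reference(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Normalize by the root mean square over the last dimension, then scale by the weight.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight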
Linear
The Linear layer is a linear transformation whose main computation is matrix multiplication (MatMul). Depending on the scenario, a bias addition may also be required (for example, the query, key, and value projections use a bias). The following code integrates these computations into one network structure:
from typing import Optional, Type
from mindspore import nn, ops, mint, Parameter, Tensor
class Qwen2Linear(nn.Cell):
def __init__(self, input_size: int, output_size: int, param_dtype: Optional[Type], enable_bias: bool) -> None:
super().__init__()
self.param_dtype = param_dtype
self.input_size = input_size
self.output_size = output_size
self.enable_bias = enable_bias
self.matmul = ops.MatMul(transpose_b=True)
self.weight = Parameter(
mint.zeros(
(self.output_size, self.input_size),
dtype=self.param_dtype
),
requires_grad=False
)
if self.enable_bias:
self.bias_add = ops.Add()
self.bias = Parameter(
mint.zeros(self.output_size, dtype=self.param_dtype)
)
def construct(self, input: Tensor):
origin_shape = input.shape
x = self.matmul(input.view(-1, origin_shape[-1]), self.weight)
if self.enable_bias:
x = self.bias_add(x, self.bias)
return x.view(*origin_shape[:-1], -1)
Because multi-batch computation is required, the leading dimensions of the input may vary (only the last dimension is always input_size). To ensure correct computation, the original input shape is saved, the input is flattened to two dimensions for the matrix multiplication, and the shape is then restored with view.
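A small shape check (illustrative only; the sizes below correspond to the Qwen2-7B key/value projection) demonstrates that the leading dimensions are preserved:
from mindspore import dtype, mint

linear = Qwen2Linear(input_size=3584, output_size=512, param_dtype=dtype.bfloat16, enable_bias=True)
x = mint.ones((2, 16, 3584), dtype=dtype.bfloat16)  # batch=2, seq=16, hidden=3584
y = linear(x)
print(y.shape)  # (2, 16, 512)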
Qwen2ForCausalLM
The Qwen2 model is usually encapsulated for specific services. For example, Qwen2ForCausalLM is an encapsulation of Qwen2 for language processing and dialog services.
The Qwen2ForCausalLM class is used to clearly define the main APIs of the model. The following shows the specific implementation:
from glob import glob
from typing import Optional, Type
from mindspore import nn, Tensor, load_checkpoint, load_param_into_net
class Qwen2ForCausalLM(nn.Cell):
def __init__(self, config: Qwen2Config) -> None:
super().__init__()
self.model = Qwen2Model(config=config)
self.lm_head = Qwen2Linear(
input_size=config.hidden_size,
output_size=config.vocab_size,
param_dtype=config.param_dtype,
enable_bias=False
)
def load_weight(self, weight_path: str) -> None:
weight_dict = {}
for path in glob(weight_path + "/*.safetensors"):
weight_dict.update(load_checkpoint(path, format="safetensors"))
load_param_into_net(self, weight_dict, strict_load=False)
def construct(self, model_input: Qwen2ModelInput) -> Tensor:
hidden_state = self.model(model_input.input_ids, model_input.positions,
model_input.batch_valid_length, model_input.is_prefill,
model_input.k_caches, model_input.v_caches, model_input.slot_mapping,
model_input.block_tables, model_input.attn_mask, model_input.q_seq_lens)
logits = self.lm_head(hidden_state)[:, -1]
return logits
As shown in the code, Qwen2ForCausalLM has two core APIs:
load_weight: loads the safetensors weights of the Hugging Face model and injects them into the network defined by the script (a usage sketch follows this list).
construct: performs the inference computation, calling the submodules layer by layer. As the code shows, the core of the model is the backbone network followed by the final lm_head linear computation, which converts the hidden_size features into a probability distribution over the vocab_size vocabulary.
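A hypothetical setup sketch (the checkpoint directory is a placeholder) that builds the network from the Hugging Face config and loads its safetensors weights:
config = Qwen2Config.from_json("/path/to/Qwen2-7B-Instruct/config.json")
model = Qwen2ForCausalLM(config)
model.load_weight("/path/to/Qwen2-7B-Instruct")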
Qwen2Model
Qwen2Model is the main network of the Qwen2 model. It consists of two parts: the embedding layer that converts the input into features and the decoder structure of n Transformer layers.
Embedding
The logic of the embedding layer is simple: it uses a gather operator to look up, for each input token ID, a feature vector of length hidden_size from the embedding weight (which is also part of the trained weights). The code is as follows:
from typing import Optional, Type
from mindspore import nn, ops, mint, Parameter, Tensor
class VocabEmbedding(nn.Cell):
def __init__(self, config: Qwen2Config) -> None:
super().__init__()
self.num_embeddings = config.vocab_size
self.embedding_dim = config.hidden_size
self.gather = ops.Gather()
self.weight = Parameter(
mint.zeros(
(self.num_embeddings, self.embedding_dim),
dtype=config.param_dtype
),
requires_grad=False
)
def construct(self, input_ids: Tensor):
return self.gather(self.weight, input_ids, 0)
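An illustrative lookup, reusing the config object built earlier (the token IDs are arbitrary): each ID selects one row of the embedding table, so the output has one hidden_size vector per input token.
from mindspore import Tensor, dtype

embedding = VocabEmbedding(config)                   # weight shape: (vocab_size, hidden_size)
ids = Tensor([151644, 872, 198], dtype=dtype.int32)  # three hypothetical token IDs
features = embedding(ids)                            # shape: (3, hidden_size)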
DecoderLayer
DecoderLayer is the core computing unit of the Transformer network, and most of the computation happens in this layer. As shown in the Qwen2 network structure diagram, it contains the RoPE, Attention, and MLP network layers. To facilitate development, these network layers are constructed first.
RoPE
The rotary position embedding (RoPE) operator enhances the attention mechanism's ability to perceive the distance between tokens by adding positional information to the query and key features. Thanks to the properties of RoPE, the rotation coefficients can be precomputed and obtained by a table lookup, making the computation efficient. This can be implemented with the gather and RoPE operators. For details about the calculation method, see the related RoPE documents.
import numpy as np
from typing import Optional, Tuple, Type
from mindspore import nn, ops, mint, Parameter, Tensor


class Qwen2RotaryEmbedding(nn.Cell):
    def __init__(self, head_size: int, rotary_dim: int, max_position_embeddings: int,
                 base: int, dtype: Optional[Type]) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.dtype = dtype
        # format 2 is neox style
        self.rotary_embedding_op = ops.ApplyRotaryPosEmb(2)
        self.gather = ops.Gather()
        self.freqs_cos, self.freqs_sin = self._compute_cos_sin_cache()

    def _compute_inv_freq(self) -> Tensor:
        freqs_base = mint.arange(0, self.rotary_dim, 2).astype(np.float32)
        freqs = 1.0 / (self.base ** (freqs_base / self.rotary_dim))
        return freqs

    def _compute_cos_sin_cache(self) -> Tuple[Tensor, Tensor]:
        freqs = self._compute_inv_freq()
        t = np.arange(0, self.max_position_embeddings, 1).astype(np.float32)
        freqs = np.outer(t, freqs)
        emb = np.concatenate((freqs, freqs), axis=1)
        freqs_cos = np.cos(emb)
        freqs_sin = np.sin(emb)
        freqs_cos = Tensor(freqs_cos, dtype=self.dtype)
        freqs_sin = Tensor(freqs_sin, dtype=self.dtype)
        return freqs_cos, freqs_sin

    def construct(self, positions: Tensor, query: Tensor, key: Tensor,
                  batch_valid_length: Tensor, is_prefill: bool):
        query = query.contiguous()
        key = key.contiguous()

        if is_prefill:
            freqs_cos = self.freqs_cos
            freqs_sin = self.freqs_sin
        else:
            freqs_cos = self.gather(self.freqs_cos, positions.view(-1), 0)
            freqs_sin = self.gather(self.freqs_sin, positions.view(-1), 0)

        return self.rotary_embedding_op(query, key, freqs_cos, freqs_sin, batch_valid_length)
Attention
An attention layer consists of several Linear operators, the RoPE operator, and the attention score computation. MindSpore provides two fusion operators, FlashAttention and PagedAttention, to improve the inference performance of the attention score computation.
However, because these native operators target multiple scenarios and take complex inputs, they are wrapped here to simplify their usage:
import numpy as np
from typing import Optional, Type
from mindspore import nn, ops, mint, Parameter, Tensor


class FlashAttention(nn.Cell):
    def __init__(self, scale: float, num_heads: int) -> None:
        super().__init__()
        input_layout = "TH"
        scale = scale
        pre_tokens = 2147483647
        next_tokens = 2147483647
        self.flash_attention = ops.operations.nn_ops.FlashAttentionScore(head_num=num_heads,
                                                                         scale_value=scale,
                                                                         pre_tokens=pre_tokens,
                                                                         next_tokens=next_tokens,
                                                                         input_layout=input_layout)

    def construct(self, q: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor, batch_valid_length: Tensor) -> Tensor:
        _, _, _, output = self.flash_attention(
            q,
            k,
            v,
            None,
            None,
            None,
            attn_mask,
            None,
            batch_valid_length,
            batch_valid_length
        )
        return output


class PagedAttention(nn.Cell):
    def __init__(self, head_num: int, scale: float, num_kv_heads: int) -> None:
        super().__init__()
        self.head_num = head_num
        self.num_kv_heads = num_kv_heads
        self.paged_attention = ops.auto_generate.PagedAttention(
            head_num=head_num,
            scale_value=scale,
            kv_head_num=num_kv_heads
        )

    def construct(self, q: Tensor, k_cache: Tensor, v_cache: Tensor, block_tables: Tensor,
                  batch_valid_length: Tensor, attn_mask: Tensor, q_seq_lens: Tensor) -> Tensor:
        output = self.paged_attention(q, k_cache, v_cache, block_tables, batch_valid_length,
                                      None, None, attn_mask, q_seq_lens)
        return output
The attention layer itself can then be implemented using the network layers constructed above:
import numpy as np
from typing import Optional, Type
from mindspore import nn, ops, mint, Parameter, Tensor


class Qwen2Attention(nn.Cell):
    def __init__(self, config: Qwen2Config) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // self.num_heads
        self.q_size = self.head_dim * self.num_heads
        self.kv_size = self.head_dim * self.num_kv_heads
        self.scaling = float(self.head_dim ** -0.5)
        self.rope_theta = int(config.rope_theta)
        self.param_dtype = config.param_dtype
        self.max_position = config.max_position_embeddings

        self.flash_attn = FlashAttention(self.scaling, self.num_heads)
        self.paged_attn = PagedAttention(self.num_heads, self.scaling, self.num_kv_heads)
        self.reshape_and_cache = ops.auto_generate.ReshapeAndCache()

        self.q_proj = Qwen2Linear(
            input_size=self.hidden_size,
            output_size=self.q_size,
            param_dtype=self.param_dtype,
            enable_bias=True
        )
        self.k_proj = Qwen2Linear(
            input_size=self.hidden_size,
            output_size=self.kv_size,
            param_dtype=self.param_dtype,
            enable_bias=True
        )
        self.v_proj = Qwen2Linear(
            input_size=self.hidden_size,
            output_size=self.kv_size,
            param_dtype=self.param_dtype,
            enable_bias=True
        )
        self.o_proj = Qwen2Linear(
            input_size=self.q_size,
            output_size=self.hidden_size,
            param_dtype=self.param_dtype,
            enable_bias=False
        )
        self.rotary_emb = Qwen2RotaryEmbedding(
            head_size=self.head_dim,
            rotary_dim=self.head_dim,
            max_position_embeddings=self.max_position,
            base=self.rope_theta,
            dtype=self.param_dtype
        )

    def construct(self, hidden_state: Tensor, positions: Tensor, batch_valid_length: Tensor,
                  is_prefill: bool, layer_idx: int, k_cache: Tensor, v_cache: Tensor,
                  slot_mapping: Tensor, block_tables: Tensor, attn_mask: Tensor,
                  q_seq_lens: Tensor) -> Tensor:
        bs, seq_len, hidden_dim = hidden_state.shape

        q = self.q_proj(hidden_state).view(-1, self.q_size)
        k = self.k_proj(hidden_state).view(-1, self.kv_size)
        v = self.v_proj(hidden_state).view(-1, self.kv_size)

        q, k = self.rotary_emb(
            positions,
            q,
            k,
            batch_valid_length,
            is_prefill
        )

        k = k.contiguous()
        v = v.contiguous()

        cache_out = self.reshape_and_cache(
            k,
            v,
            k_cache,
            v_cache,
            slot_mapping
        )
        q = ops.depend(q, cache_out)

        if is_prefill:
            attn_output = self.flash_attn(
                q,
                k,
                v,
                attn_mask,
                batch_valid_length
            )
        else:
            attn_output = self.paged_attn(
                q,
                k_cache,
                v_cache,
                block_tables,
                batch_valid_length,
                attn_mask,
                q_seq_lens
            )

        output = self.o_proj(attn_output).view(bs, seq_len, -1)
        return output
MLP
An MLP layer, consisting of several Linear operators and an activation function (usually silu), implements the non-linear computation of the network. It projects the features into non-linear spaces, thereby enhancing the capacity of the network. For details about the implementation, see the following code:
import numpy as np
from typing import Optional, Type
from mindspore import nn, ops, mint, Parameter, Tensor


class Qwen2MLP(nn.Cell):
    def __init__(self, config: Qwen2Config) -> None:
        super().__init__()
        self.up_proj = Qwen2Linear(
            input_size=config.hidden_size,
            output_size=config.intermediate_size,
            param_dtype=config.param_dtype,
            enable_bias=False
        )
        self.gate_proj = Qwen2Linear(
            input_size=config.hidden_size,
            output_size=config.intermediate_size,
            param_dtype=config.param_dtype,
            enable_bias=False
        )
        self.down_proj = Qwen2Linear(
            input_size=config.intermediate_size,
            output_size=config.hidden_size,
            param_dtype=config.param_dtype,
            enable_bias=False
        )
        self.act_fn = ops.silu

    def construct(self, x: Tensor) -> Tensor:
        output = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return output
With the preceding network layers in place, DecoderLayer can be constructed as follows:
from typing import Tuple
from mindspore import nn, Tensor
class Qwen2DecoderLayer(nn.Cell):
def __init__(self, config: Qwen2Config) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Qwen2Attention(config=config)
self.mlp = Qwen2MLP(config=config)
self.input_layernorm = RmsNorm(config=config)
self.post_attention_layernorm = RmsNorm(config=config)
def construct(self, hidden_state: Tensor, residual: Tensor, positions: Tensor,
batch_valid_length: Tensor, is_prefill: bool, layer_idx: int,
k_cache: Tensor, v_cache: Tensor, slot_mapping: Tensor,
block_tables: Tensor, attn_mask: Tensor, q_seq_lens: Tensor) -> Tuple[Tensor, Tensor]:
if residual is None:
residual = hidden_state
hidden_state = self.input_layernorm(hidden_state)
else:
hidden_state, residual = self.input_layernorm(hidden_state, residual)
hidden_state = self.self_attn(hidden_state, positions, batch_valid_length, is_prefill,
layer_idx, k_cache, v_cache, slot_mapping, block_tables,
attn_mask, q_seq_lens)
hidden_state, residual = self.post_attention_layernorm(hidden_state, residual)
hidden_state = self.mlp(hidden_state)
return hidden_state, residual
Model
After the embedding and decoder layers are constructed, you can construct the Model class by referring to the following code:
from typing import List
from mindspore import nn, ops, mint, Parameter, Tensor
class Qwen2Model(nn.Cell):
def __init__(self, config: Qwen2Config) -> None:
super().__init__()
self.vocab_size = config.vocab_size
self.hidden_size = config.hidden_size
self.num_hidden_layers = config.num_hidden_layers
self.embed_tokens = VocabEmbedding(config=config)
self.layers = nn.CellList()
for i in range(config.num_hidden_layers):
layer = Qwen2DecoderLayer(config=config)
self.layers.append(layer)
self.norm = RmsNorm(config=config)
def construct(self, input_ids: Tensor, positions: Tensor, batch_valid_length: Tensor,
is_prefill: bool, k_caches: List[Tensor], v_caches: List[Tensor],
slot_mapping: Tensor, block_tables: Tensor, attn_mask: Tensor,
q_seq_lens: Tensor) -> Tensor:
hidden_state = self.embed_tokens(input_ids)
residual = None
for i in range(self.num_hidden_layers):
layer = self.layers[i]
hidden_state, residual = layer(hidden_state, residual, positions, batch_valid_length,
is_prefill, i, k_caches[i], v_caches[i], slot_mapping,
block_tables, attn_mask, q_seq_lens)
hidden_state, _ = self.norm(hidden_state, residual)
return hidden_state
KVCacheManager
KVCache is a common optimization for LLM inference. To use KVCache together with the FlashAttention and PagedAttention operators provided by MindSpore, several additional parameters must be passed, including:
k_cache & v_cache: the KVCache objects can be considered cache tables that store the keys and values computed in previous iterations. In subsequent iterations, these values are read directly, avoiding recomputation of the keys and values of the first n tokens and improving performance.
block_tables & slot_mapping: PagedAttention stores the KVCache in blocks using a paging-like mechanism, so that the cached entries of a sequence are grouped into blocks, improving device memory utilization.
Based on the preceding description, these parameters can be encapsulated in a management class, as shown in the following reference code:
import math
from collections import deque
from typing import Tuple
from mindspore import nn, ops, mint, Parameter, Tensor, dtype, mutable
class CacheManager:
def __init__(self, config: Qwen2Config, block_num: int, block_size: int, batch_size: int) -> None:
self.block_num = block_num
self.block_size = block_size
self.batch_size = batch_size
head_dim = config.hidden_size // config.num_attention_heads
self.k_caches = mutable([ops.zeros((block_num, block_size, config.num_key_value_heads, head_dim), dtype=config.param_dtype) for _ in range(config.num_hidden_layers)])
self.v_caches = mutable([ops.zeros((block_num, block_size, config.num_key_value_heads, head_dim), dtype=config.param_dtype) for _ in range(config.num_hidden_layers)])
self.block_tables = [[] for _ in range(batch_size)]
self.acc_slot_mapping = [[] for _ in range(batch_size)]
self.free_block_ids = deque(range(block_num))
def step(self, start_pos_idx: int, token_num_per_batch: int) -> Tuple[Tensor, Tensor]:
for i in range(self.batch_size):
block_table = self.block_tables[i]
total_block_num = math.ceil((start_pos_idx + token_num_per_batch) / self.block_size)
now_block_num = len(block_table)
for _ in range(total_block_num - now_block_num):
block_id = self.free_block_ids.popleft()
block_table.append(block_id)
start_slot_id = block_id * self.block_size
self.acc_slot_mapping[i].extend(list(range(start_slot_id, start_slot_id + self.block_size)))
now_block_tables = Tensor(self.block_tables, dtype=dtype.int32)
now_slot_mapping = Tensor([self.acc_slot_mapping[i][start_pos_idx: start_pos_idx + token_num_per_batch]
for i in range(self.batch_size)], dtype=dtype.int32).view(-1)
return now_block_tables, now_slot_mapping
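A hypothetical usage sketch, reusing the config object built earlier (block_num and block_size are arbitrary choices): allocate caches for a single-request batch, then obtain the block tables and slot mapping for a 16-token prefill followed by one decode step.
cache_manager = CacheManager(config, block_num=128, block_size=16, batch_size=1)

# Prefill: the first 16 tokens occupy exactly one block per request.
block_tables, slot_mapping = cache_manager.step(start_pos_idx=0, token_num_per_batch=16)

# Decode: the 17th token triggers allocation of a second block.
block_tables, slot_mapping = cache_manager.step(start_pos_idx=16, token_num_per_batch=1)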
Sampler
After the backbone network computation, the network output is a tensor of shape [batch_size, vocab_size], which represents, for each inference request in the batch, the probability distribution of the next token. A token must be selected from the vocabulary as the final result. To simplify selection and eliminate randomness, the token with the maximum probability is chosen each time, that is, argmax (greedy) sampling. The following is a code example:
from mindspore import Tensor
def sample(logits: Tensor) -> Tensor:
next_token = logits.argmax(axis=-1, keepdims=True)
return next_token
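For clarity (illustrative only): the logits returned by Qwen2ForCausalLM have shape (batch_size, vocab_size), so greedy sampling yields one next-token ID per request in the batch.
next_tokens = sample(logits)  # shape: (batch_size, 1), one greedy token ID per request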
Converting Dynamic Graphs to Static Graphs
MindSpore can convert a dynamic graph to a static graph through just-in-time (JIT) compilation to improve inference performance. In code, the conversion only requires a simple decorator:
from typing import List
from mindspore import nn, ops, mint, Parameter, Tensor, jit
class Qwen2Model(nn.Cell):
def __init__(self, config: Qwen2Config) -> None:
super().__init__()
self.vocab_size = config.vocab_size
self.hidden_size = config.hidden_size
self.num_hidden_layers = config.num_hidden_layers
self.embed_tokens = VocabEmbedding(config=config)
self.layers = nn.CellList()
for i in range(config.num_hidden_layers):
layer = Qwen2DecoderLayer(config=config)
self.layers.append(layer)
self.norm = RmsNorm(config=config)
@jit(jit_level="O0", infer_boost="on")
def construct(self, input_ids: Tensor, positions: Tensor, batch_valid_length: Tensor,
is_prefill: bool, k_caches: List[Tensor], v_caches: List[Tensor],
slot_mapping: Tensor, block_tables: Tensor, attn_mask: Tensor,
q_seq_lens: Tensor) -> Tensor:
hidden_state = self.embed_tokens(input_ids)
residual = None
for i in range(self.num_hidden_layers):
layer = self.layers[i]
hidden_state, residual = layer(hidden_state, residual, positions, batch_valid_length,
is_prefill, i, k_caches[i], v_caches[i], slot_mapping,
block_tables, attn_mask, q_seq_lens)
hidden_state, _ = self.norm(hidden_state, residual)
return hidden_state
Add the mindspore.jit decorator to the construct method of nn.Cell to execute the computation of the cell in static graph mode. The parameters are described as follows:
jit_level: specifies the compilation level. MindSpore inference currently supports the O0 and O1 levels (O1 adds some operator fusion optimizations).
infer_boost: enables inference acceleration optimizations. When enabled, some scheduling and stream optimizations are applied at runtime to improve inference performance.
In addition, due to the limitations of the static graph mode of MindSpore, dynamic-to-static conversion may fail in some scenarios. The following lists some common causes:
setattr usage: Python's setattr syntax is not supported during MindSpore graph capture, so inputs cannot be passed wrapped in an encapsulation class. For example, the Qwen2ModelInput object in the preceding example cannot be passed directly to a Qwen2Model that has been converted to a static graph; doing so causes the static graph execution to fail.
List values: If list parameters are passed when the graph is converted to a static graph, they must be wrapped with mutable (see the sketch after this list) so that MindSpore can process them correctly, for example, k_caches and v_caches in the preceding example. Otherwise, a fallback to Python is triggered, which degrades inference performance, and in some scenarios the computation may even fail.
Graph input names: If the MindSpore PagedAttention operator is used, the two corresponding graph inputs must be named batch_valid_length and q_seq_lens. Otherwise, the PagedAttention operator fails to initialize.
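A minimal sketch of the list-value limitation (the cache shapes and layer count are arbitrary): wrapping the cache lists with mutable marks them as variable graph inputs rather than constants when they are passed to a cell compiled with @jit.
from mindspore import mutable, ops, dtype

k_caches = mutable([ops.zeros((128, 16, 4, 128), dtype=dtype.bfloat16) for _ in range(28)])
v_caches = mutable([ops.zeros((128, 16, 4, 128), dtype=dtype.bfloat16) for _ in range(28)])
# model(..., k_caches, v_caches, ...)  # pass the wrapped lists into the static graph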
If you plan to use static graph inference later, you are advised to keep the preceding limitations in mind during dynamic graph development and debugging to avoid extra migration and debugging costs afterwards.