mindspore.ops.lightning_indexer
- mindspore.ops.lightning_indexer(query, key, weights, *, actual_seq_lengths_query=None, actual_seq_lengths_key=None, block_table=None, layout_query='BSND', layout_key='BSND', sparse_count=2048, sparse_mode=3, pre_tokens=2**63 - 1, next_tokens=2**63 - 1, return_value=False) -> Tuple[Tensor, Tensor]
Computes the Top-k sparse indices for each query token based on the DeepSeek Sparse Attention (DSA) algorithm.
This operator is a core component of the DSA mechanism introduced in DeepSeek-V3.2, which enables efficient long-context modeling by selectively attending to the most relevant tokens rather than processing the entire input sequence.
\[Indices = \text{Top-}k\left\{[1]_{1\times g} @ \left[(W @ [1]_{1\times S_k}) \odot \text{ReLU}\left(Q_{index} @ K_{index}^T\right)\right]\right\}\]
where \(Q_{index} \in \mathbb{R}^{g \times d}\) is the Index Query, \(K_{index} \in \mathbb{R}^{S_k \times d}\) is the context Index Key, \(W \in \mathbb{R}^{g \times 1}\) holds the head weights, \(g\) is the GQA group size, \(d\) is the head dimension, and \(S_k\) is the context length.
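The formula can be read as "weight each head's ReLU score, sum over heads, keep the k best key positions". The following NumPy sketch evaluates it for a single query token. It is only an illustration of the math: the function name `lightning_indexer_reference` and the small sizes are hypothetical, it ignores sparse_mode masking, and the ordering and tie-breaking of the real operator's output are not specified here.

```python
import numpy as np

def lightning_indexer_reference(q_index, k_index, w, k):
    """NumPy sketch of the formula for one query token (hypothetical helper).

    q_index: (g, d) index query, k_index: (S_k, d) index key,
    w: (g, 1) head weights, k: number of indices to keep.
    """
    # ReLU(Q @ K^T): per-head relevance of every key token, shape (g, S_k)
    scores = np.maximum(q_index @ k_index.T, 0.0)
    # (W @ [1]_{1 x S_k}) repeats each head weight along the key axis
    weighted = (w @ np.ones((1, k_index.shape[0]))) * scores
    # [1]_{1 x g} @ ... sums over the g heads, shape (1, S_k)
    token_scores = np.ones((1, q_index.shape[0])) @ weighted
    # Top-k: positions of the k largest aggregated scores, highest first
    order = np.argsort(-token_scores[0], kind="stable")
    return order[:k], token_scores[0]

rng = np.random.default_rng(0)
g, d, s_k, top_k = 4, 8, 16, 5  # toy sizes, not the operator's supported ones
q = rng.standard_normal((g, d))
key = rng.standard_normal((s_k, d))
w = rng.standard_normal((g, 1))
indices, token_scores = lightning_indexer_reference(q, key, w, top_k)
print(indices.shape)
```

Multiplying by the all-ones matrices is equivalent to broadcasting `w` over the key axis and summing over the head axis; the explicit form is kept here only to mirror the notation.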
B: Batch size.
S1: Sequence length of query.
S2: Sequence length of key.
T1: Total tokens of query in TND layout (sum of all batch sequence lengths).
T2: Total tokens of key in TND layout.
N1: Number of heads of query. Supports 64 (only 64 when layout_query is TND).
N2: Number of heads of key. Only supports 1.
D: Head dimension. Only supports 128.
Warning
Only supports Atlas A2 inference series products (Ascend platform).
Note
The data types of query, key, and weights must be consistent.
In non-PageAttention scenarios, layout_key should be consistent with layout_query.
- Parameters
query (Tensor) – The query tensor with shape \((B, S1, N1, D)\) or \((T1, N1, D)\). Supported data types: mindspore.bfloat16, mindspore.float16.
key (Tensor) – The key tensor with shape \((B, S2, N2, D)\), \((T2, N2, D)\), or \((block\_count, block\_size, N2, D)\) for PageAttention. Supported data types: mindspore.bfloat16, mindspore.float16.
weights (Tensor) – The head weights tensor with shape \((B, S1, N1)\) or \((T1, N1)\) for weighted aggregation. Its dtype must be the same as that of query.
- Keyword Arguments
actual_seq_lengths_query (Tensor, optional) – The effective token count for each batch in query with shape \((B,)\). Supported data type: mindspore.int32. When layout_query is "TND", this parameter is required and represents the prefix sum (cumulative sum) of sequence lengths. Default: None.
actual_seq_lengths_key (Tensor, optional) – The effective token count for each batch in key with shape \((B,)\). Supported data type: mindspore.int32. When layout_key is "TND" or "PA_BSND", this parameter is required. Default: None.
block_table (Tensor, optional) – The block mapping table for PageAttention KV storage with shape \((B, max\_blocks)\). Supported data type: mindspore.int32. Required when layout_key is "PA_BSND". Default: None.
layout_query (str, optional) – Specifies the data layout of input query. Supports "BSND" and "TND". Default: "BSND".
layout_key (str, optional) – Specifies the data layout of input key. Supports "BSND", "TND", and "PA_BSND". Default: "BSND".
sparse_count (int, optional) – The number of Top-k tokens to retain in the sparse selection. Value range is [1, 2048]. Default: 2048.
sparse_mode (int, optional) – Specifies the sparse mode. Default: 3.
0: defaultMask mode.
3: rightDownCausal mode, which corresponds to the lower-triangular scenario divided from the lower-right vertex.
pre_tokens (int, optional) – Reserved parameter for sparse computation that represents how many tokens are counted forward. Only supports the default value. Default: 2^63-1.
next_tokens (int, optional) – Reserved parameter for sparse computation that represents how many tokens are counted backward. Only supports the default value. Default: 2^63-1.
return_value (bool, optional) – Whether to output valid sparse_values. True is only supported when layout_key is not "PA_BSND". Default: False.
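For the "TND" query layout, actual_seq_lengths_query is described above as a prefix sum rather than a list of per-batch lengths. The following NumPy sketch shows one way to build it; the per-batch lengths in `seq_lens` are made-up values for illustration.

```python
import numpy as np

# Hypothetical per-batch query lengths for a batch of 3 sequences
seq_lens = np.array([3, 5, 2], dtype=np.int32)

# Prefix sum (cumulative sum) of the lengths, kept as int32
actual_seq_lengths_query = np.cumsum(seq_lens).astype(np.int32)

# The last entry equals T1, the total number of query tokens in TND layout
t1 = int(actual_seq_lengths_query[-1])
print(actual_seq_lengths_query, t1)
```

The resulting array can then be wrapped in a Tensor and passed as the actual_seq_lengths_query keyword argument.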
- Returns
A tuple of tensors containing sparse_indices and sparse_values.
sparse_indices contains the indices of the Top-k tokens, with dtype mindspore.int32 and shape \((B, S1, N2, sparse\_count)\) or \((T1, N2, sparse\_count)\).
sparse_values contains the corresponding values of the Top-k tokens, with the same dtype as query and the same shape as sparse_indices. Only valid when return_value is True.
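To make the output shape concrete, the sketch below uses NumPy to gather the key vectors that a sparse_indices tensor of shape \((B, S1, N2, sparse\_count)\) points at in a BSND key. It is an illustration only: the indices are random stand-ins rather than operator output, the sizes are toy values, and it assumes every index is a valid position in \([0, S2)\), which this page does not state.

```python
import numpy as np

b, s1, s2, n2, d, sparse_count = 2, 3, 10, 1, 4, 5  # toy sizes
rng = np.random.default_rng(0)
key = rng.standard_normal((b, s2, n2, d))

# Stand-in for sparse_indices: random valid key positions (assumption)
sparse_indices = rng.integers(0, s2, size=(b, s1, n2, sparse_count)).astype(np.int32)

# Broadcast batch and head indices against sparse_indices, then gather
batch_idx = np.arange(b)[:, None, None, None]
head_idx = np.arange(n2)[None, None, :, None]
selected = key[batch_idx, sparse_indices, head_idx]

print(selected.shape)  # (b, s1, n2, sparse_count, d)
```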
- Supported Platforms:
Ascend
Examples
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> # BSND layout with PA_BSND key layout example
>>> b, s1, s2, n1, n2, d = 1, 1, 8192, 64, 1, 128
>>> block_size = 256
>>> query = Tensor(np.random.randn(b, s1, n1, d).astype(np.float16))
>>> key = Tensor(np.random.randn(b * (s2 // block_size), block_size, n2, d).astype(np.float16))
>>> weights = Tensor(np.random.randn(b, s1, n1).astype(np.float16))
>>> actual_seq_lengths_query = Tensor(np.array([s1]).astype(np.int32))
>>> actual_seq_lengths_key = Tensor(np.array([s2]).astype(np.int32))
>>> block_table = Tensor(np.arange(b * s2 // block_size).reshape(b, -1).astype(np.int32))
>>> sparse_indices, sparse_values = ops.lightning_indexer(
...     query, key, weights,
...     actual_seq_lengths_query=actual_seq_lengths_query,
...     actual_seq_lengths_key=actual_seq_lengths_key,
...     block_table=block_table,
...     layout_query='BSND',
...     layout_key='PA_BSND',
...     sparse_count=2048,
...     sparse_mode=3
... )
>>> print(sparse_indices.shape)
(1, 1, 1, 2048)