mindspore.ops.lightning_indexer

mindspore.ops.lightning_indexer(query, key, weights, *, actual_seq_lengths_query=None, actual_seq_lengths_key=None, block_table=None, layout_query='BSND', layout_key='BSND', sparse_count=2048, sparse_mode=3, pre_tokens=2 ^ 63 - 1, next_tokens=2 ^ 63 - 1, return_value=False) -> Tuple[Tensor, Tensor]

Computes the Top-k sparse indices for each query token based on the DeepSeek Sparse Attention (DSA) algorithm.

This operator is a core component of the DSA mechanism introduced in DeepSeek-V3.2, which enables efficient long-context modeling by selectively attending to the most relevant tokens rather than processing the entire input sequence.

\[Indices = \text{Top-}k\left\{[1]_{1\times g} @ \left[(W @ [1]_{1\times S_k}) \odot \text{ReLU}\left(Q_{index} @ K_{index}^T\right)\right]\right\}\]

where \(Q_{index} \in \mathbb{R}^{g \times d}\) is the Index Query, \(K_{index} \in \mathbb{R}^{S_k \times d}\) is the context Index Key, \(W \in \mathbb{R}^{g \times 1}\) holds the head weights, \(g\) is the GQA group size, \(d\) is the head dimension, and \(S_k\) is the context length.
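
The formula can be read step by step in a small NumPy sketch for a single query token. This is only an illustration of the math above, not the operator's implementation: the function name `lightning_indexer_reference` is hypothetical, the sizes are toy values, and the ordering of the returned indices (descending score here) is an assumption that the operator may not share.

```python
import numpy as np

def lightning_indexer_reference(q_index, k_index, w, k):
    """NumPy sketch of the index-score formula for one query token.

    q_index: (g, d) index query, k_index: (S_k, d) index key,
    w: (g, 1) head weights, k: number of indices to keep.
    """
    # ReLU(Q_index @ K_index^T): per-head relevance scores, shape (g, S_k).
    scores = np.maximum(q_index @ k_index.T, 0.0)
    # (W @ [1]_{1 x S_k}) elementwise-times scores: broadcasting w over the
    # key axis is equivalent to the explicit outer product with a ones row.
    weighted = w * scores
    # [1]_{1 x g} @ (...): sum over the g heads, shape (S_k,).
    token_scores = weighted.sum(axis=0)
    # Top-k: positions of the k largest scores (descending order is assumed).
    order = np.argsort(-token_scores, kind="stable")
    return order[:k], token_scores

rng = np.random.default_rng(0)
g, d, s_k = 4, 8, 16  # toy sizes; the operator itself requires D = 128
q = rng.standard_normal((g, d)).astype(np.float32)
key = rng.standard_normal((s_k, d)).astype(np.float32)
w = rng.standard_normal((g, 1)).astype(np.float32)

indices, token_scores = lightning_indexer_reference(q, key, w, k=5)
print(indices.shape)  # (5,)
```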

  • B: Batch size.

  • S1: Sequence length of query.

  • S2: Sequence length of key.

  • T1: Total tokens of query in TND layout (sum of all batch sequence lengths).

  • T2: Total tokens of key in TND layout.

  • N1: Number of heads of query. Supports 64 (only 64 when layout_query is TND).

  • N2: Number of heads of key. Only supports 1.

  • D: Head dimension. Only supports 128.
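
To make the relationship between the BSND and TND dimensions concrete, the following NumPy sketch packs variable-length query sequences into a single \((T1, N1, D)\) array and computes the cumulative (prefix-sum) lengths that the TND layout uses for actual_seq_lengths_query. The per-batch lengths are made-up values chosen for illustration.

```python
import numpy as np

# Hypothetical per-batch query lengths (B = 3).
seq_lengths = np.array([3, 5, 2], dtype=np.int32)
n1, d = 64, 128

# T1 is the total number of query tokens across all batches.
t1 = int(seq_lengths.sum())

# Prefix sum of the sequence lengths: entry i is the end offset of batch i.
cumulative_lengths = np.cumsum(seq_lengths).astype(np.int32)

# Pack per-batch (S1_i, N1, D) arrays into one (T1, N1, D) array.
per_batch = [np.zeros((int(s), n1, d), dtype=np.float16) for s in seq_lengths]
query_tnd = np.concatenate(per_batch, axis=0)

print(t1)                           # 10
print(cumulative_lengths.tolist())  # [3, 8, 10]
print(query_tnd.shape)              # (10, 64, 128)
```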

Warning

Only supports Atlas A2 inference series products (Ascend platform).

Note

  • The data types of query, key, and weights must be consistent.

  • In non-PageAttention scenarios, layout_key should be consistent with layout_query.

Parameters
  • query (Tensor) – The query tensor with shape \((B, S1, N1, D)\) or \((T1, N1, D)\). Supported data types: mindspore.bfloat16 , mindspore.float16 .

  • key (Tensor) – The key tensor with shape \((B, S2, N2, D)\), \((T2, N2, D)\), or \((block\_count, block\_size, N2, D)\) for PageAttention. Supported data types: mindspore.bfloat16 , mindspore.float16 .

  • weights (Tensor) – The head weights tensor used for weighted aggregation, with shape \((B, S1, N1)\) or \((T1, N1)\). The data type must be the same as that of query.

Keyword Arguments
  • actual_seq_lengths_query (Tensor, optional) – The effective token count for each batch in query with shape \((B,)\). Supported data type: mindspore.int32 . When layout_query is "TND" , this parameter is required and represents the prefix sum (cumulative sum) of sequence lengths. Default: None .

  • actual_seq_lengths_key (Tensor, optional) – The effective token count for each batch in key with shape \((B,)\). Supported data type: mindspore.int32 . When layout_key is "TND" or "PA_BSND" , this parameter is required. Default: None .

  • block_table (Tensor, optional) – The block mapping table for PageAttention KV storage with shape \((B, max\_blocks)\). Supported data type: mindspore.int32 . Required when layout_key is "PA_BSND" . Default: None .

  • layout_query (str, optional) – Specifies the data layout of input query. Supports "BSND" and "TND" . Default: "BSND" .

  • layout_key (str, optional) – Specifies the data layout of input key. Supports "BSND" , "TND" , and "PA_BSND" . Default: "BSND" .

  • sparse_count (int, optional) – The number of Top-k tokens to retain in the sparse selection. Value range is [1, 2048]. Default: 2048 .

  • sparse_mode (int, optional) –

    Specifies the sparse mode. Default: 3 .

    • 0: defaultMask mode.

    • 3: rightDownCausal mode, a causal (lower-triangular) mask whose diagonal is aligned to the lower-right corner of the query-key matrix.

  • pre_tokens (int, optional) – Reserved parameter for sparse computation that specifies how many tokens are counted forward. Only the default value is supported. Default: 2^63-1 .

  • next_tokens (int, optional) – Reserved parameter for sparse computation that specifies how many tokens are counted backward. Only the default value is supported. Default: 2^63-1 .

  • return_value (bool, optional) – Whether to output valid sparse_values. True is supported only when layout_key is not "PA_BSND" . Default: False .
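
As an illustration of sparse_mode 3, the sketch below builds a rightDownCausal mask in NumPy. It assumes the common bottom-right-aligned causal convention, in which the last query token may see every key token; the helper name `right_down_causal_mask` is hypothetical and the exact masking performed inside the operator is not confirmed here.

```python
import numpy as np

def right_down_causal_mask(s1, s2):
    """Boolean (s1, s2) mask; True marks key positions a query may select."""
    q_pos = np.arange(s1)[:, None]
    k_pos = np.arange(s2)[None, :]
    # Shift the diagonal by (s2 - s1) so it ends at the lower-right corner.
    return k_pos <= q_pos + (s2 - s1)

mask = right_down_causal_mask(3, 5)
print(mask.astype(int))
# [[1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 1]]
```

When S1 equals S2 this reduces to the ordinary lower-triangular causal mask.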

Returns

A tuple of tensors containing sparse_indices and sparse_values.

  • sparse_indices is the indices of Top-k tokens with dtype mindspore.int32 and shape \((B, S1, N2, sparse\_count)\) or \((T1, N2, sparse\_count)\).

  • sparse_values is the corresponding values of Top-k tokens with the same dtype as query and the same shape as sparse_indices. Only valid when return_value is True .

Supported Platforms:

Ascend

Examples

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> # Example: BSND query layout with a PA_BSND (PageAttention) key layout
>>> b, s1, s2, n1, n2, d = 1, 1, 8192, 64, 1, 128
>>> block_size = 256
>>> query = Tensor(np.random.randn(b, s1, n1, d).astype(np.float16))
>>> key = Tensor(np.random.randn(b * (s2 // block_size), block_size, n2, d).astype(np.float16))
>>> weights = Tensor(np.random.randn(b, s1, n1).astype(np.float16))
>>> actual_seq_lengths_query = Tensor(np.array([s1]).astype(np.int32))
>>> actual_seq_lengths_key = Tensor(np.array([s2]).astype(np.int32))
>>> block_table = Tensor(np.arange(b * s2 // block_size).reshape(b, -1).astype(np.int32))
>>> sparse_indices, sparse_values = ops.lightning_indexer(
...     query, key, weights,
...     actual_seq_lengths_query=actual_seq_lengths_query,
...     actual_seq_lengths_key=actual_seq_lengths_key,
...     block_table=block_table,
...     layout_query='BSND',
...     layout_key='PA_BSND',
...     sparse_count=2048,
...     sparse_mode=3
... )
>>> print(sparse_indices.shape)
(1, 1, 1, 2048)
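
The PA_BSND key layout used in this example stores keys in fixed-size blocks that block_table maps back to each batch. The NumPy sketch below shows one way to rebuild the contiguous per-batch keys from such storage. It assumes the usual paged-attention convention that block_table[b, i] is the physical block index of the i-th logical block of batch b; the helper name `gather_paged_key` and the toy sizes are hypothetical.

```python
import numpy as np

def gather_paged_key(key_blocks, block_table, seq_lengths):
    """Rebuild contiguous per-batch keys from paged storage.

    key_blocks: (block_count, block_size, N2, D) paged key storage
    block_table: (B, max_blocks) physical block index per logical block
    seq_lengths: (B,) number of valid key tokens per batch
    """
    block_size = key_blocks.shape[1]
    keys = []
    for b, length in enumerate(seq_lengths):
        n_blocks = -(-int(length) // block_size)        # ceiling division
        blocks = key_blocks[block_table[b, :n_blocks]]  # (n_blocks, block_size, N2, D)
        flat = blocks.reshape(-1, *key_blocks.shape[2:])
        keys.append(flat[: int(length)])                # drop padding in the last block
    return keys

block_size, n2, d = 4, 1, 8  # toy sizes; the operator itself requires D = 128
key_blocks = np.arange(6 * block_size * n2 * d, dtype=np.float32).reshape(6, block_size, n2, d)
block_table = np.array([[2, 0, 5], [1, 3, 4]], dtype=np.int32)
seq_lengths = np.array([10, 6], dtype=np.int32)

keys = gather_paged_key(key_blocks, block_table, seq_lengths)
print([k.shape for k in keys])  # [(10, 1, 8), (6, 1, 8)]
```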