mindspore.ops.nsa_select_attention

mindspore.ops.nsa_select_attention(query, key, value, topk_indices, scale_value, head_num, select_block_size, select_block_count, *, atten_mask=None, actual_seq_qlen, actual_seq_kvlen) -> Tuple[Tensor, Tensor, Tensor]

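Computes the selection attention stage of the Native Sparse Attention (NSA) algorithm for training scenarios.

Instead of attending to every key/value token, each query token (for each key/value head) attends only to the select_block_count blocks of select_block_size consecutive key/value tokens whose block indices are listed in topk_indices. This keeps the cost of long-sequence attention proportional to the number of selected blocks rather than to the full key/value length.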
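To make the block selection concrete, the following NumPy sketch computes the same kind of output for a single batch on the CPU. It is a hypothetical reference, not MindSpore's Ascend kernel: the function name `select_attention_reference`, the grouped-query head mapping (query head h reads key/value head h // (N_1 / N_2)), and the absence of any causal mask are assumptions. For packed multi-batch input, it would be applied to each batch segment separately.

```python
import numpy as np


def select_attention_reference(q, k, v, topk, scale, block_size=64):
    """Hypothetical single-batch CPU reference for block-selection attention.

    q: (T1, N1, D1), k: (T2, N2, D1), v: (T2, N2, D2)
    topk: (T1, N2, block_count) integer block indices along the T2 axis.
    Assumption: query head h shares key/value head h // (N1 // N2).
    Assumption: no causal or other mask is applied.
    """
    # Upcast to float32 for a numerically stable CPU reference.
    q, k, v = (x.astype(np.float32) for x in (q, k, v))
    t1, n1, _ = q.shape
    t2, n2, d2 = v.shape
    group = n1 // n2
    out = np.zeros((t1, n1, d2), dtype=np.float32)
    for t in range(t1):
        for h in range(n1):
            kv_h = h // group
            # Expand each selected block index into its token positions.
            pos = (topk[t, kv_h][:, None] * block_size + np.arange(block_size)).ravel()
            pos = pos[pos < t2]  # drop positions past the end of the KV sequence
            scores = (k[pos, kv_h] @ q[t, h]) * scale  # shape (len(pos),)
            probs = np.exp(scores - scores.max())      # stable softmax
            probs /= probs.sum()
            out[t, h] = probs @ v[pos, kv_h]           # shape (D2,)
    return out
```

If topk_indices lists every block exactly once, the result reduces to ordinary dense attention, which makes a convenient sanity check for a sketch like this.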

Warning

  • The layout of query, key, and value is fixed to TND (tokens, heads, head dimension): the sequences of all batches are packed along the first axis.

  • It is only supported on Atlas A2 Training Series Products.

  • Out-of-range values in topk_indices may lead to undefined behavior.

Parameters
  • query (Tensor) – The query tensor with shape \((T_1, N_1, D_1)\), where \(T_1\) is the total number of query tokens across all batches, \(N_1\) is the number of query heads, and \(D_1\) is the head dimension. Supported data types: mindspore.bfloat16, mindspore.float16. Non-contiguous tensors are supported; empty tensors are not.

  • key (Tensor) – The key tensor with shape \((T_2, N_2, D_1)\), where \(T_2\) is the total number of key tokens across all batches, \(N_2\) is the number of key heads, and \(D_1\) is the head dimension, which must match that of query. Supported data types: mindspore.bfloat16, mindspore.float16. Non-contiguous tensors are supported; empty tensors are not.

  • value (Tensor) – The value tensor with shape \((T_2, N_2, D_2)\), where \(T_2\) and \(N_2\) match those of key and \(D_2\) is the value head dimension. Supported data types: mindspore.bfloat16, mindspore.float16. Non-contiguous tensors are supported; empty tensors are not.

  • topk_indices (Tensor) – The block indices with shape \((T_1, N_2, select\_block\_count)\), specifying for each query token and key/value head which key/value blocks take part in the attention. Supported data type: mindspore.int32. Non-contiguous tensors are supported; empty tensors are not. For each batch, every element of topk_indices must satisfy \(0 \leq index \leq S_2 / 64\), where \(S_2\) is the valid KV sequence length of that batch and 64 is the value of select_block_size.

  • scale_value (float) – The scaling factor applied to the attention scores, typically \(D_1^{-0.5}\), where \(D_1\) is the head dimension of query and key.

  • head_num (int) – The number of attention heads per device; it must equal \(N_1\), the length of the head axis of query.

  • select_block_size (int) – The number of key/value tokens in each selection block. Currently only 64 is supported.

  • select_block_count (int) – The number of blocks selected per query token and key/value head. When select_block_size is 64, it must be 16.

Keyword Arguments
  • atten_mask (Tensor, optional) – The attention mask tensor. Currently not supported. Default: None.

  • actual_seq_qlen (list[int]) – The query sequence lengths of the batches in cumulative (prefix-sum) form: a non-decreasing list of integers whose last value equals \(T_1\). For example, two batches of 128 query tokens each give [128, 256].

  • actual_seq_kvlen (list[int]) – The key/value sequence lengths of the batches in cumulative (prefix-sum) form: a non-decreasing list of integers whose last value equals \(T_2\). For example, two batches of 512 key/value tokens each give [512, 1024].
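Preparing TND input amounts to concatenating the per-batch sequences along the token axis and recording the running totals of their lengths. The NumPy sketch below shows one way to do this on the host; the helper names `pack_tnd` and `unpack_tnd` are hypothetical. It calls `.tolist()` so the cumulative lengths are plain Python ints, since the operator expects list[int].

```python
import numpy as np


def pack_tnd(seqs):
    """Hypothetical helper: pack per-batch (S_i, N, D) arrays into TND layout.

    Returns the packed (T, N, D) array and the cumulative (prefix-sum)
    lengths in the form used by actual_seq_qlen / actual_seq_kvlen.
    """
    lengths = [s.shape[0] for s in seqs]
    packed = np.concatenate(seqs, axis=0)   # (sum of S_i, N, D)
    cum_lens = np.cumsum(lengths).tolist()  # Python ints, e.g. [128, 256]
    assert cum_lens[-1] == packed.shape[0]  # last value equals T
    return packed, cum_lens


def unpack_tnd(packed, cum_lens):
    """Split a TND array back into per-batch segments using cumulative lengths."""
    # np.split takes the interior boundaries, so the final total is dropped.
    return np.split(packed, cum_lens[:-1], axis=0)
```

For instance, packing two query batches of 128 tokens each yields a (256, N, D) array together with [128, 256].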

Returns

A tuple of three tensors, (attention_out, softmax_max, softmax_sum).

  • attention_out is the attention output, with shape \((T_1, N_1, D_2)\).

  • softmax_max is the row-wise maximum computed during the softmax, with shape \((T_1, N_1, 8)\); it is saved for the gradient computation.

  • softmax_sum is the row-wise sum of exponentials computed during the softmax, with shape \((T_1, N_1, 8)\); it is saved for the gradient computation.
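Saving the softmax maximum and sum lets the backward pass rebuild the attention probabilities without storing the full probability matrix. The NumPy sketch below shows how such statistics are commonly defined in FlashAttention-style kernels; that this operator uses exactly this definition is an assumption, and the trailing axis of size 8 in the returned tensors is not modeled. The function name `softmax_stats` is hypothetical.

```python
import numpy as np


def softmax_stats(scores):
    """Row-wise softmax statistics in the FlashAttention style (assumed convention).

    scores: (..., L) scaled attention scores, one softmax row per leading index.
    Returns (row_max, row_sum, probs), where
    probs = exp(scores - row_max) / row_sum.
    """
    row_max = scores.max(axis=-1, keepdims=True)
    shifted = np.exp(scores - row_max)  # subtracting the max avoids overflow
    row_sum = shifted.sum(axis=-1, keepdims=True)
    probs = shifted / row_sum
    return row_max, row_sum, probs
```

Given only row_max and row_sum, the probabilities can be recomputed as exp(scores - row_max) / row_sum, and row_max + log(row_sum) is the log-sum-exp of each row.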

Raises
  • TypeError – If query, key, value, or topk_indices is not a Tensor.

  • TypeError – If scale_value is not a float.

  • TypeError – If head_num, select_block_size, or select_block_count is not an int.

  • TypeError – If actual_seq_qlen or actual_seq_kvlen is not a list of int.

  • RuntimeError – If the data types of query, key, and value are inconsistent.

  • RuntimeError – If the batch sizes of query, key, and value are not equal.

  • RuntimeError – If head_num does not match \(N_1\), the length of the head axis of query.

  • RuntimeError – If topk_indices contains values outside the valid range \(0 \leq index \leq S_2 / 64\).

  • RuntimeError – If the dimension constraints are not satisfied: query and key must share the head dimension \(D_1\), and \(D_1 \geq D_2\).
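Because out-of-range entries in topk_indices may lead to undefined behavior or a RuntimeError, it can be worth checking them on the host before calling the operator. The NumPy sketch below applies the documented per-batch rule \(0 \leq index \leq S_2 / 64\), taking the upper bound as inclusive as stated above. The helper name `check_topk_indices` and its use of ValueError are hypothetical. For a KV batch of 512 tokens the bound is 512 / 64 = 8, so indices drawn from [0, 16) can violate it.

```python
import numpy as np


def check_topk_indices(topk_indices, actual_seq_qlen, actual_seq_kvlen,
                       select_block_size=64):
    """Hypothetical host-side pre-check mirroring the documented range rule.

    For batch b, the query rows [q_start, q_end) may only hold indices in
    [0, S_2 / select_block_size], where S_2 is that batch's KV length.
    Raises ValueError on the first violating batch.
    """
    # Prepend 0 so consecutive entries give each batch's [start, end).
    q_bounds = [0] + list(actual_seq_qlen)
    kv_bounds = [0] + list(actual_seq_kvlen)
    for b in range(len(actual_seq_qlen)):
        s2 = kv_bounds[b + 1] - kv_bounds[b]
        upper = s2 / select_block_size
        rows = topk_indices[q_bounds[b]:q_bounds[b + 1]]
        if rows.size and (rows.min() < 0 or rows.max() > upper):
            raise ValueError(
                f"batch {b}: indices must lie in [0, {upper:g}], "
                f"got [{rows.min()}, {rows.max()}]")
```

The check is cheap relative to the attention itself, so it can stay enabled while debugging index generation.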

Supported Platforms:

Ascend

Examples

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> # Create input tensors
>>> query = Tensor(np.random.randn(256, 16, 192).astype(np.float16))
>>> key = Tensor(np.random.randn(1024, 4, 192).astype(np.float16))
>>> value = Tensor(np.random.randn(1024, 4, 128).astype(np.float16))
>>> # Each batch has 512 KV tokens, i.e. 512 / 64 = 8 blocks, so keep indices below 8
>>> topk_indices = Tensor(np.random.randint(0, 8, size=(256, 4, 16)).astype(np.int32))
>>> scale_value = 1.0 / (192 ** 0.5)  # Typical scaling factor
>>> head_num = 16
>>> select_block_size = 64
>>> select_block_count = 16
>>> actual_seq_qlen = [128, 256]  # Cumulative sequence lengths for query
>>> actual_seq_kvlen = [512, 1024]  # Cumulative sequence lengths for key/value
>>> # Compute native sparse attention
>>> attention_out, softmax_max, softmax_sum = ops.nsa_select_attention(
...     query, key, value, topk_indices, scale_value, head_num,
...     select_block_size, select_block_count,
...     actual_seq_qlen=actual_seq_qlen, actual_seq_kvlen=actual_seq_kvlen
... )
>>> print(attention_out.shape)
(256, 16, 128)
>>> print(softmax_max.shape)
(256, 16, 8)
>>> print(softmax_sum.shape)
(256, 16, 8)