mindspore.ops.nsa_select_attention

查看源文件
mindspore.ops.nsa_select_attention(query, key, value, topk_indices, scale_value, head_num, select_block_size, select_block_count, *, atten_mask=None, actual_seq_qlen, actual_seq_kvlen) Tuple[Tensor, Tensor, Tensor][源代码]

本算子用于在训练场景中计算原生稀疏注意力(Native Sparse Attention)算法的选择性注意力机制。

该算子实现了原生稀疏注意力中的选择性注意力计算,通过依据 topk_indices 选择特定的注意力块以高效计算注意力权重。

警告

  • 输入 querykeyvalue 的布局固定为 TND

  • 只支持 Atlas A2 训练系列产品。

  • topk_indices 若存在取值越界,可能导致未定义行为。

参数:
  • query (Tensor) - 输入 query 张量,shape 为 \((T_1, N_1, D_1)\),其中 \(T_1\) 为序列长度, \(N_1\) 为注意力头数, \(D_1\) 为单头维度。支持数据类型:mindspore.bfloat16mindspore.float16 。支持非连续 Tensor,不支持空 Tensor。

  • key (Tensor) - 输入 key 张量,shape 为 \((T_2, N_2, D_1)\),其中 \(T_2\)key 序列长度, \(N_2\)key 的头数, \(D_1\) 为单头维度(与 query 相同)。支持数据类型:mindspore.bfloat16mindspore.float16 。支持非连续 Tensor,不支持空 Tensor。

  • value (Tensor) - 输入 value 张量,shape 为 \((T_2, N_2, D_2)\),其中 \(T_2\)value 序列长度, \(N_2\)value 的头数, \(D_2\)value 的单头维度。支持数据类型:mindspore.bfloat16mindspore.float16 。支持非连续 Tensor,不支持空 Tensor。

  • topk_indices (Tensor) - 索引张量,shape 为 \((T_1, N_2, select\_block\_count)\),用于指定选择哪些注意力块。支持数据类型:mindspore.int32 。支持非连续 Tensor,不支持空 Tensor。对于每个 batch,topk_indices 的每个元素必须满足 \(0 \leq index \leq S_2 / 64\),其中 \(S_2\) 为该 batch 的有效 KV 序列长度,64select_block_size

  • scale_value (float) - 作用于注意力分数的缩放因子,通常设为 \(D^{-0.5}\),其中 \(D\) 为头维度。

  • head_num (int) - 每设备的注意力头数量,应等于 query\(N_1\) 轴长度。

  • select_block_size (int) - 选择窗口的大小。目前仅支持 64

  • select_block_count (int) - 选择窗口的数量。当 select_block_size64 时,该参数值应为 16

关键字参数:
  • atten_mask (Tensor,可选) - 注意力掩码张量。目前不支持。默认: None

  • actual_seq_qlen (list[int],可选) - 每个 batch 中 query 对应的大小(前缀和模式),必须为非递减整数序列,最后一个值等于 \(T_1\)

  • actual_seq_kvlen (list[int],可选) - 每个 batch 中 keyvalue 对应的大小(前缀和模式),必须为非递减整数序列,最后一个值等于 \(T_2\)

返回:

一个Tensor元组,包含 attention_outsoftmax_maxsoftmax_sum

  • attention_out 是注意力的输出结果。

  • softmax_max 是Softmax计算的中间最大值结果,用于反向计算。

  • softmax_sum 是Softmax计算的中间求和结果,用于反向计算。

异常:
  • TypeError - querykeyvaluetopk_indices 不是 Tensor。

  • TypeError - scale_value 不是 float。

  • TypeError - head_numselect_block_sizeselect_block_count 不是 int。

  • TypeError - actual_seq_qlenactual_seq_kvlen 在提供时不是 int 列表。

  • RuntimeError - querykeyvalue 的数据类型不一致。

  • RuntimeError - querykeyvalue 的 batch 大小不相等。

  • RuntimeError - head_numquery 的头维度不匹配。

  • RuntimeError - topk_indices 存在超出有效范围 \(0 \leq index \leq S_2 / 64\) 的取值。

  • RuntimeError - 维度约束不满足:\(D_q == D_k\)\(D_k >= D_v\)

支持平台:

Ascend

样例:

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> # Create input tensors
>>> query = Tensor(np.random.randn(256, 16, 192).astype(np.float16))
>>> key = Tensor(np.random.randn(1024, 4, 192).astype(np.float16))
>>> value = Tensor(np.random.randn(1024, 4, 128).astype(np.float16))
>>> topk_indices = Tensor(np.random.randint(0, 16, size=(256, 4, 16)).astype(np.int32))
>>> scale_value = 1.0 / (192 ** 0.5)  # Typical scaling factor
>>> head_num = 16
>>> select_block_size = 64
>>> select_block_count = 16
>>> actual_seq_qlen = [128, 256]  # Cumulative sequence lengths for query
>>> actual_seq_kvlen = [512, 1024]  # Cumulative sequence lengths for key/value
>>> # Compute native sparse attention
>>> attention_out, softmax_max, softmax_sum = ops.nsa_select_attention(
...     query, key, value, topk_indices, scale_value, head_num,
...     select_block_size, select_block_count,
...     actual_seq_qlen=actual_seq_qlen, actual_seq_kvlen=actual_seq_kvlen
... )
>>> print(attention_out.shape)
(256, 16, 128)
>>> print(softmax_max.shape)
(256, 16, 8)
>>> print(softmax_sum.shape) 
(256, 16, 8)