mindspore.ops.nsa_compress_attention
- mindspore.ops.nsa_compress_attention(query, key, value, scale_value, head_num, compress_block_size, compress_stride, select_block_size, select_block_count, *, topk_mask=None, atten_mask=None, actual_seq_qlen, actual_cmp_seq_kvlen, actual_sel_seq_kvlen)[源代码]
使用NSA Compress Attention算法进行注意力压缩计算(Ascend)。
该算子通过以下步骤计算压缩注意力:
计算注意力分数并应用掩码:
\[QK = scale \cdot query \cdot key^T\]应用注意力掩码(如果提供):
\[QK = \text{atten_mask}(QK, atten\_mask)\]计算压缩注意力分数:
\[P_{cmp} = \text{Softmax}(QK)\]计算注意力输出:
\[\text{attentionOut} = P_{cmp} \cdot value\]计算重要性分数(Importance Score):
\[P_{slc}[j] = \sum_{m=0}^{l'/d-1} \sum_{n=0}^{l/d-1} P_{cmp}[l'/d \cdot j - m - n]\]其中 \(l'\) 为 select KV 序列长度,\(l\) 为 compress KV 序列长度,\(d\) 为 compress_stride。
聚合多头的重要性分数:
\[P_{slc}' = \sum_{h=1}^{H} P_{slc}^h\]其中 \(H\) 为头数。
应用 TopK 掩码:
\[P_{slc}' = \text{topk_mask}(P_{slc}')\]选择 TopK 索引:
\[\text{topkIndices} = \text{topk}(P_{slc}')\]
说明
Ascend 平台下内部布局固定为 "TND"。
actual_seq_qlen 、 actual_cmp_seq_kvlen 、 actual_sel_seq_kvlen 采用前缀和模式。
compress_block_size 必须是16的整数倍,支持范围:16到128。
compress_stride 必须是16的整数倍,支持范围:16到64。
select_block_size 必须是16的整数倍,支持范围:16到128。
select_block_count 支持范围:1到32。
compress_block_size >= compress_stride
select_block_size >= compress_block_size
select_block_size % compress_stride == 0
头维度约束:query和key的头维度 D1 必须相同,且 D1 >= D2 (key的头维度大于等于value的头维度)。
头数约束: N1 >= N2 (query的头数大于等于key的头数)且 N1 % N2 == 0 (query的头数必须是key头数的整数倍)。
D1 和 D2 必须是16的整数倍。
- 参数:
query (Tensor) - 查询张量,形状为 (T1, N1, D1) ,其中 T1 为查询序列长度,N1 为查询头数,D1 为头维度。dtype类型为 float16 或 bfloat16。必选参数。
key (Tensor) - 键张量,形状为 (T2, N2, D1) ,其中 T2 为键序列长度,N2 为键头数,D1 为头维度(与 query 相同)。dtype类型为 float16 或 bfloat16。必选参数。
value (Tensor) - 值张量,形状为 (T2, N2, D2) ,其中 T2 为值序列长度,N2 为值头数(与 key 相同),D2 为值头维度。dtype类型为 float16 或 bfloat16。必选参数。
scale_value (float) - 注意力分数的缩放因子。必选参数。
head_num (int) - 注意力头数,应等于 query 的头数 N1 。必选参数。
compress_block_size (int) - 压缩滑窗大小。必选参数。
compress_stride (int) - 相邻滑窗间距。必选参数。
select_block_size (int) - 选择块大小。必选参数。
select_block_count (int) - 选择块个数。必选参数。
- 关键字参数:
topk_mask (Tensor, 可选) - TopK掩码张量,形状为 (S1, S2) ,其中 S1 为查询序列长度,S2 为select KV序列长度。dtype为 bool。默认值:
None。atten_mask (Tensor, 可选) - 注意力掩码张量,形状为 (S1, S2) ,其中 S1 为查询序列长度,S2 为compress KV序列长度。dtype为 bool。默认值:
None。actual_seq_qlen (Union[tuple[int], list[int]]) - 批次query序列长度(前缀和)。
actual_cmp_seq_kvlen (Union[tuple[int], list[int]]) - 批次compress KV序列长度(前缀和)。
actual_sel_seq_kvlen (Union[tuple[int], list[int]]) - 批次select KV序列长度(前缀和)。
- 返回:
tuple,包含四个张量。
attention_out (Tensor) - 注意力输出张量,形状为 (T1, N1, D2) 。
topk_indices_out (Tensor) - TopK索引张量,形状为 (T1, N2, select_block_count) ,其中 T1 为查询序列长度,N2 为键头数。
softmax_max_out (Tensor) - Softmax计算的Max中间结果,形状为 (T1, N1, 8) 。
softmax_sum_out (Tensor) - Softmax计算的Sum中间结果,形状为 (T1, N1, 8) 。
- 异常:
TypeError - 如果输入参数类型不正确。
ValueError - 如果输入张量形状或参数不满足约束。
- 支持平台:
Ascend
样例:
>>> import numpy as np >>> from mindspore import Tensor, ops >>> T, N, D = 256, 8, 128 >>> query = Tensor(np.random.randn(T, N, D).astype(np.float16)) >>> key = Tensor(np.random.randn(T, N, D).astype(np.float16)) >>> value = Tensor(np.random.randn(T, N, D).astype(np.float16)) >>> topk_mask = Tensor(np.ones((T, T), dtype=np.bool_)) >>> atten_mask = Tensor(np.ones((T, T), dtype=np.bool_)) >>> actual_seq_qlen = [T] >>> actual_cmp_seq_kvlen = [T] >>> actual_sel_seq_kvlen = [T] >>> compress_block_size = 32 >>> compress_stride = 16 >>> select_block_size = 64 >>> select_block_count = 16 >>> scale_value = 1.0 / (D ** 0.5) >>> head_num = N >>> attention_out, topk_indices_out, softmax_max_out, softmax_sum_out = ops.nsa_compress_attention( ... query, key, value, scale_value, head_num, compress_block_size, compress_stride, ... select_block_size, select_block_count, topk_mask, atten_mask, ... actual_seq_qlen, actual_cmp_seq_kvlen, actual_sel_seq_kvlen) >>> print(attention_out.shape) (256, 8, 128)