mindspore.ops.nsa_compress_attention

查看源文件
mindspore.ops.nsa_compress_attention(query, key, value, scale_value, head_num, compress_block_size, compress_stride, select_block_size, select_block_count, *, topk_mask=None, atten_mask=None, actual_seq_qlen, actual_cmp_seq_kvlen, actual_sel_seq_kvlen)[源代码]

使用NSA Compress Attention算法进行注意力压缩计算(Ascend)。

该算子通过以下步骤计算压缩注意力:

  1. 计算注意力分数并应用掩码:

    \[QK = scale \cdot query \cdot key^T\]

    应用注意力掩码(如果提供):

    \[QK = \text{atten_mask}(QK, atten\_mask)\]
  2. 计算压缩注意力分数:

    \[P_{cmp} = \text{Softmax}(QK)\]
  3. 计算注意力输出:

    \[\text{attentionOut} = P_{cmp} \cdot value\]
  4. 计算重要性分数(Importance Score):

    \[P_{slc}[j] = \sum_{m=0}^{l'/d-1} \sum_{n=0}^{l/d-1} P_{cmp}[l'/d \cdot j - m - n]\]

    其中 \(l'\) 为 select KV 序列长度,\(l\) 为 compress KV 序列长度,\(d\) 为 compress_stride。

  5. 聚合多头的重要性分数:

    \[P_{slc}' = \sum_{h=1}^{H} P_{slc}^h\]

    其中 \(H\) 为头数。

  6. 应用 TopK 掩码:

    \[P_{slc}' = \text{topk_mask}(P_{slc}')\]
  7. 选择 TopK 索引:

    \[\text{topkIndices} = \text{topk}(P_{slc}')\]

说明

  • Ascend 平台下内部布局固定为 "TND"。

  • actual_seq_qlenactual_cmp_seq_kvlenactual_sel_seq_kvlen 采用前缀和模式。

  • compress_block_size 必须是16的整数倍,支持范围:16到128。

  • compress_stride 必须是16的整数倍,支持范围:16到64。

  • select_block_size 必须是16的整数倍,支持范围:16到128。

  • select_block_count 支持范围:1到32。

  • compress_block_size >= compress_stride

  • select_block_size >= compress_block_size

  • select_block_size % compress_stride == 0

  • 头维度约束:query和key的头维度 D1 必须相同,且 D1 >= D2 (key的头维度大于等于value的头维度)。

  • 头数约束: N1 >= N2 (query的头数大于等于key的头数)且 N1 % N2 == 0 (query的头数必须是key头数的整数倍)。

  • D1D2 必须是16的整数倍。

参数:
  • query (Tensor) - 查询张量,形状为 (T1, N1, D1) ,其中 T1 为查询序列长度,N1 为查询头数,D1 为头维度。dtype类型为 float16 或 bfloat16。必选参数。

  • key (Tensor) - 键张量,形状为 (T2, N2, D1) ,其中 T2 为键序列长度,N2 为键头数,D1 为头维度(与 query 相同)。dtype类型为 float16 或 bfloat16。必选参数。

  • value (Tensor) - 值张量,形状为 (T2, N2, D2) ,其中 T2 为值序列长度,N2 为值头数(与 key 相同),D2 为值头维度。dtype类型为 float16 或 bfloat16。必选参数。

  • scale_value (float) - 注意力分数的缩放因子。必选参数。

  • head_num (int) - 注意力头数,应等于 query 的头数 N1 。必选参数。

  • compress_block_size (int) - 压缩滑窗大小。必选参数。

  • compress_stride (int) - 相邻滑窗间距。必选参数。

  • select_block_size (int) - 选择块大小。必选参数。

  • select_block_count (int) - 选择块个数。必选参数。

关键字参数:
  • topk_mask (Tensor, 可选) - TopK掩码张量,形状为 (S1, S2) ,其中 S1 为查询序列长度,S2 为select KV序列长度。dtype为 bool。默认值: None

  • atten_mask (Tensor, 可选) - 注意力掩码张量,形状为 (S1, S2) ,其中 S1 为查询序列长度,S2 为compress KV序列长度。dtype为 bool。默认值: None

  • actual_seq_qlen (Union[tuple[int], list[int]]) - 批次query序列长度(前缀和)。

  • actual_cmp_seq_kvlen (Union[tuple[int], list[int]]) - 批次compress KV序列长度(前缀和)。

  • actual_sel_seq_kvlen (Union[tuple[int], list[int]]) - 批次select KV序列长度(前缀和)。

返回:

tuple,包含四个张量。

  • attention_out (Tensor) - 注意力输出张量,形状为 (T1, N1, D2)

  • topk_indices_out (Tensor) - TopK索引张量,形状为 (T1, N2, select_block_count) ,其中 T1 为查询序列长度,N2 为键头数。

  • softmax_max_out (Tensor) - Softmax计算的Max中间结果,形状为 (T1, N1, 8)

  • softmax_sum_out (Tensor) - Softmax计算的Sum中间结果,形状为 (T1, N1, 8)

异常:
  • TypeError - 如果输入参数类型不正确。

  • ValueError - 如果输入张量形状或参数不满足约束。

支持平台:

Ascend

样例:

>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> T, N, D = 256, 8, 128
>>> query = Tensor(np.random.randn(T, N, D).astype(np.float16))
>>> key = Tensor(np.random.randn(T, N, D).astype(np.float16))
>>> value = Tensor(np.random.randn(T, N, D).astype(np.float16))
>>> topk_mask = Tensor(np.ones((T, T), dtype=np.bool_))
>>> atten_mask = Tensor(np.ones((T, T), dtype=np.bool_))
>>> actual_seq_qlen = [T]
>>> actual_cmp_seq_kvlen = [T]
>>> actual_sel_seq_kvlen = [T]
>>> compress_block_size = 32
>>> compress_stride = 16
>>> select_block_size = 64
>>> select_block_count = 16
>>> scale_value = 1.0 / (D ** 0.5)
>>> head_num = N
>>> attention_out, topk_indices_out, softmax_max_out, softmax_sum_out = ops.nsa_compress_attention(
...     query, key, value, scale_value, head_num, compress_block_size, compress_stride, 
...     select_block_size, select_block_count, topk_mask, atten_mask, 
...     actual_seq_qlen, actual_cmp_seq_kvlen, actual_sel_seq_kvlen)
>>> print(attention_out.shape)
(256, 8, 128)