mindspore.ops.sparse_flash_attention
- mindspore.ops.sparse_flash_attention(query, key, value, sparse_indices, scale_value, *, block_table=None, actual_seq_lengths_query=None, actual_seq_lengths_kv=None, query_rope=None, key_rope=None, sparse_block_size=1, layout_query='BSND', layout_kv='BSND', sparse_mode=3, pre_tokens=2 ^ 63 - 1, next_tokens=2 ^ 63 - 1, attention_mode=0, return_softmax_lse=False) Tuple[Tensor, Tensor, Tensor][源代码]
针对大序列长度推理场景的稀疏注意力计算模块。该模块通过"只计算关键部分"大幅减少计算量,基于稀疏索引选取重要的 Key 和 Value 进行注意力计算。 以下为公式与参数 shape 中维度的含义:
B:batch维。
S1: query 即公式里的 Q 的序列长度。
S2: key 和 value 即公式里的 K 和 V 的序列长度。
N1: query 的多头数量,支持 1/2/4/8/16/32/64/128。
N2: key 和 value 的多头数量,仅支持 1。
D: query / key / value 的头维度,固定为 512。
D_rope: query_rope / key_rope 的头维度,固定为 64。
说明
attention_mode 约束:当前仅支持 attention_mode=2(MLA-absorb模式)。使用 attention_mode=2 时,必须同时提供 query_rope 和 key_rope。
return_softmax_lse 约束:True表示返回,但图模式下不支持,False表示不返回;默认值为False。该参数仅在训练且 layout_kv 不为
"PA_BSND"场景支持。layout 约束:layout_kv 为
"PA_BSND"时,layout_query 和 layout_kv 无需一致;layout_kv 为"BSND"或"TND"时,layout_query 和 layout_kv 需保持一致。
sparse_flash_attention的计算公式定义如下:
\[\text{attention_out} = \text{softmax}(\frac{Q \cdot \tilde{K}^T}{\sqrt{d_k}}) \cdot \tilde{V}\]其中 \(\tilde{K}, \tilde{V}\) 为通过特定选取算法(如 lightning_indexer)选取的重要性较高的 Key 和 Value, 通常具有稀疏或块稀疏特性。\(d_k\) 为 \(Q, \tilde{K}\) 的头维度。
警告
这是一个实验性API,可能会发生变更或被删除。
只支持 Atlas A2/A3 推理系列产品。
- 参数:
query (Tensor) - query Tensor,对应公式中的 Q。支持的数据类型为 float16 与 bfloat16。当 layout_query 为
"BSND"时shape为 \((B, S1, N1, D)\),当 layout_query 为"TND"时shape为 \((T1, N1, D)\)。N1 支持 1/2/4/8/16/32/64/128。key (Tensor) - key Tensor,对应公式中的 K。数据类型与 query 一致。当 layout_kv 为
"BSND"时shape为 \((B, S2, N2, D)\),当 layout_kv 为"TND"时shape为 \((T2, N2, D)\),当 layout_kv 为"PA_BSND"时shape为 \((block\_num, block\_size, N2, D)\)。N2 仅支持 1。value (Tensor) - value Tensor,对应公式中的 V,数据类型以及 shape 与 key 一致。
sparse_indices (Tensor) - 稀疏索引Tensor,用于离散KV缓存访问。数据类型为 int32。当 layout_query 为
"BSND"时shape为 \((B, Q\_S, N2, sparse\_size)\),当 layout_query 为"TND"时shape为 \((Q\_T, N2, sparse\_size)\)。sparse_size 必须为 2048。scale_value (float) - 缩放系数,用作query和key矩阵乘法后的乘法标量,通常为 1.0 / (D ** 0.5)。
- 关键字参数:
block_table (Tensor,可选) - PageAttention场景下的块映射表。数据格式为 ND,数据类型为 int32,shape为 \((B,)\)。默认值为:
None。actual_seq_lengths_query (Tensor,可选) - 每个batch中query的有效token数量。数据格式为 ND,数据类型为 int32,shape为 \((B,)\)。用于变长序列的前缀和格式。默认值为:
None。actual_seq_lengths_kv (Tensor,可选) - 每个batch中key和value的有效token数量。数据格式为 ND,数据类型为 int32,shape为 \((B,)\)。用于变长序列的前缀和格式。默认值为:
None。query_rope (Tensor,可选) - MLA结构中query的RoPE信息。shape与query相同,但最后一维D替换为D_rope(64)。当 attention_mode=2(MLA-absorb模式)时必须提供。 默认值为:
None。key_rope (Tensor,可选) - MLA结构中key的RoPE信息。shape与key相同,但最后一维D替换为D_rope(64)。当 attention_mode=2(MLA-absorb模式)时必须提供。 默认值为:
None。sparse_block_size (int,可选) - 稀疏阶段的块大小,用于重要性分数计算。取值范围为 [1, 128],必须是2的幂次方。默认值为:
1。layout_query (str,可选) - query的数据布局格式。支持
"BSND"和"TND"。默认值为:"BSND"。layout_kv (str,可选) - key的数据布局格式。支持
"TND"、"BSND"和"PA_BSND"。为"PA_BSND"时,无需与 layout_query 取值保持一致;为"BSND"或"TND"时,需要与 layout_query 取值保持一致。默认值为:"BSND"。sparse_mode (int,可选) - 稀疏模式。仅支持
0(全部计算)与3(rightDownCausal 掩码模式)。默认值为:3。pre_tokens (int,可选) - 稀疏计算中向前计算的token数量。默认值为:
2^63-1。next_tokens (int,可选) - 稀疏计算中向后计算的token数量。默认值为:
2^63-1。attention_mode (int,可选) - 注意力模式。当前仅支持
2(MLA-absorb 模式),即计算过程中将 query 和 key 的 nope 部分分别与 query_rope 和 key_rope 的 rope 部分沿头维度(D)拼接。使用该模式时必须同时提供 query_rope 和 key_rope。默认值为:0。return_softmax_lse (bool,可选) - 是否返回有效的 softmax_max 和 softmax_sum。为
True时返回,但图模式下不支持;为False时不返回。该参数仅在训练且 layout_kv 不为"PA_BSND"时支持。默认值为:False。
- 返回:
包含3个Tensor的元组。
attention_out (Tensor) - 注意力输出,数据类型与 query 一致。当 layout_query 为
"BSND"时shape为 \((B, S1, N1, D)\),当 layout_query 为"TND"时shape为 \((T1, N1, D)\)。softmax_max (Tensor) - softmax中的最大值,数据类型为 float32。当 layout_query 为
"BSND"时shape为 \((B, N2, S1, N1/N2)\),当 layout_query 为"TND"时shape为 \((N2, T1, N1/N2)\)。softmax_sum (Tensor) - softmax中的求和值,数据类型为 float32。shape与 softmax_max 一致。
- 异常:
TypeError - query 不是 Tensor。
TypeError - key 不是 Tensor。
TypeError - value 不是 Tensor。
TypeError - sparse_indices 不是 Tensor。
TypeError - query、key、value 的数据类型不一致。
TypeError - query、key、value 的数据类型不是 float16 或 bfloat16。
RuntimeError - layout_query 不是
"BSND"或"TND"。RuntimeError - layout_kv 不是
"BSND"、"TND"或"PA_BSND"。RuntimeError - sparse_block_size 不在范围 [1, 128] 内或不是2的幂次方。
- 支持平台:
Ascend
样例:
>>> import numpy as np >>> from mindspore import Tensor, ops >>> b, s1, s2, n1, n2, d, d_rope = 4, 1, 8192, 128, 1, 512, 64 >>> sparse_block_count = 2048 >>> query = Tensor(np.random.randn(b, s1, n1, d).astype(np.float16)) >>> key = Tensor(np.random.randn(b, s2, n2, d).astype(np.float16)) >>> value = key.copy() >>> query_rope = Tensor(np.random.randn(b, s1, n1, d_rope).astype(np.float16)) >>> key_rope = Tensor(np.random.randn(b, s2, n2, d_rope).astype(np.float16)) >>> sparse_indices = Tensor(np.random.randint(0, s2, (b, s1, n2, sparse_block_count)).astype(np.int32)) >>> scale_value = 1.0 / (d ** 0.5) >>> attention_out, softmax_max, softmax_sum = ops.sparse_flash_attention( ... query, key, value, sparse_indices, scale_value, ... query_rope=query_rope, key_rope=key_rope, attention_mode=2, return_softmax_lse=True) >>> print(attention_out.shape) (4, 1, 128, 512)