mindspore.ops.sparse_flash_attention
- mindspore.ops.sparse_flash_attention(query, key, value, sparse_indices, scale_value, *, block_table=None, actual_seq_lengths_query=None, actual_seq_lengths_kv=None, query_rope=None, key_rope=None, sparse_block_size=1, layout_query='BSND', layout_kv='BSND', sparse_mode=3, pre_tokens=2**63 - 1, next_tokens=2**63 - 1, attention_mode=0, return_softmax_lse=False) -> Tuple[Tensor, Tensor, Tensor]
Computes Sparse Flash Attention (SFA), an efficient attention module for long-sequence inference. Instead of attending to every key and value position, SFA uses sparse indices to select only the important keys and values and computes attention over that subset, which significantly reduces computation. The following dimension symbols are used in the formula and parameter shapes:
B: batch dimension.
S1: sequence length of query (Q in the formula).
S2: sequence length of key and value (K and V in the formula).
N1: number of heads for query, supports 1/2/4/8/16/32/64/128.
N2: number of heads for key and value, only supports 1.
D: head dimension of query / key / value, fixed at 512.
D_rope: head dimension of query_rope / key_rope, fixed at 64.
T1 / T2: total number of query / key-value tokens across all batches when the layout is "TND".
Note
attention_mode constraint: Currently only attention_mode=2 (MLA-absorb mode) is supported. Because the default is 0, attention_mode=2 must be passed explicitly, and query_rope and key_rope must be provided with it.
return_softmax_lse constraint: When set to True, valid softmax_max and softmax_sum are returned; this is not supported in graph mode. When set to False (the default), they are not returned. This parameter is only supported in training and when layout_kv is not "PA_BSND".
layout constraint: When layout_kv is "PA_BSND", it need not match layout_query; when layout_kv is "BSND" or "TND", it must match layout_query.
The formula is:
\[\text{attention\_out} = \text{softmax}\left(\frac{Q \cdot \tilde{K}^T}{\sqrt{d_k}}\right) \cdot \tilde{V}\]
where \(\tilde{K}, \tilde{V}\) are the keys and values judged most important by a selection algorithm (such as lightning_indexer) and referenced through sparse_indices; the selection is generally sparse or block-sparse. \(d_k\) is the head dimension of \(Q\) and \(\tilde{K}\), and the factor \(1/\sqrt{d_k}\) corresponds to scale_value.
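The formula can be checked against a small NumPy reference. This is a sketch under stated assumptions, not MindSpore's Ascend kernel: it handles one batch with a single KV head (N2 = 1), ignores query_rope/key_rope, sparse_mode masking and page-attention layouts, and the function name sparse_attention_reference is hypothetical. For each query token it gathers \(\tilde{K}, \tilde{V}\) from the rows listed in sparse_indices and applies the scaled softmax.

```python
import numpy as np


def sparse_attention_reference(q, k, v, sparse_indices, scale_value):
    """Hypothetical NumPy reference for one batch with a single KV head (N2 = 1).

    q: (S1, N1, D) queries; k, v: (S2, D) keys/values of the single KV head;
    sparse_indices: (S1, sparse_size) integer positions into S2.
    Returns an array of shape (S1, N1, D). RoPE terms and masks are ignored.
    """
    s1, n1, _ = q.shape
    out = np.empty((s1, n1, v.shape[-1]), dtype=np.float32)
    for i in range(s1):
        idx = sparse_indices[i]
        k_sel = k[idx].astype(np.float32)  # K~ for query token i: (sparse_size, D)
        v_sel = v[idx].astype(np.float32)  # V~ for query token i: (sparse_size, D)
        # All N1 query heads share the single KV head, so one matmul covers them.
        scores = (q[i].astype(np.float32) @ k_sel.T) * scale_value  # (N1, sparse_size)
        scores -= scores.max(axis=-1, keepdims=True)  # stabilize exp
        probs = np.exp(scores)
        probs /= probs.sum(axis=-1, keepdims=True)
        out[i] = probs @ v_sel
    return out


# When every query token selects all S2 positions, the result equals dense attention.
rng = np.random.default_rng(0)
s1, s2, n1, d = 2, 16, 4, 8
q = rng.standard_normal((s1, n1, d))
k = rng.standard_normal((s2, d))
v = rng.standard_normal((s2, d))
all_idx = np.tile(np.arange(s2), (s1, 1))
sparse_out = sparse_attention_reference(q, k, v, all_idx, 1.0 / np.sqrt(d))

dense_scores = np.einsum("snd,td->snt", q, k) / np.sqrt(d)
dense_probs = np.exp(dense_scores - dense_scores.max(axis=-1, keepdims=True))
dense_probs /= dense_probs.sum(axis=-1, keepdims=True)
dense_out = np.einsum("snt,td->snd", dense_probs, v)
print(np.allclose(sparse_out, dense_out, atol=1e-4))  # True
```

Because attention is a weighted sum over the selected set, the order of indices inside each row of sparse_indices does not change the result; only which positions are selected matters.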
Warning
This is an experimental API that may be changed or deleted.
Only supported on Atlas A2/A3 inference series products.
- Parameters
query (Tensor) – The query tensor, corresponding to Q in the formula. Dtype is bfloat16 or float16. When layout_query is "BSND", shape is \((B, S1, N1, D)\). When layout_query is "TND", shape is \((T1, N1, D)\). N1 supports 1/2/4/8/16/32/64/128.
key (Tensor) – The key tensor, corresponding to K in the formula. Must have the same dtype as query. When layout_kv is "BSND", shape is \((B, S2, N2, D)\). When layout_kv is "TND", shape is \((T2, N2, D)\). When layout_kv is "PA_BSND", shape is \((block\_num, block\_size, N2, D)\). N2 only supports 1.
value (Tensor) – The value tensor, corresponding to V in the formula. Same shape as key.
sparse_indices (Tensor) – Sparse index tensor used to select entries from the discrete KV cache. Dtype is int32. When layout_query is "BSND", shape is \((B, S1, N2, sparse\_size)\). When layout_query is "TND", shape is \((T1, N2, sparse\_size)\). sparse_size must be 2048.
scale_value (float) – Scaling factor applied to the product of query and key (the \(1/\sqrt{d_k}\) term in the formula). Typically 1.0 / (D ** 0.5).
- Keyword Arguments
block_table (Tensor, optional) – Block mapping table for PageAttention KV cache storage. Data format is ND, dtype is int32, shape is \((B,)\). Default: None.
actual_seq_lengths_query (Tensor, optional) – Number of valid query tokens for each batch. Data format is ND, dtype is int32, shape is \((B,)\). For variable sequence lengths, values are given in prefix-sum format. Default: None.
actual_seq_lengths_kv (Tensor, optional) – Number of valid key/value tokens for each batch. Data format is ND, dtype is int32, shape is \((B,)\). For variable sequence lengths, values are given in prefix-sum format. Default: None.
query_rope (Tensor, optional) – RoPE component of the query in the MLA structure. Required when attention_mode=2 (MLA-absorb mode). Same shape as query except that the last dimension D is replaced with D_rope (64). Default: None.
key_rope (Tensor, optional) – RoPE component of the key in the MLA structure. Required when attention_mode=2 (MLA-absorb mode). Same shape as key except that the last dimension D is replaced with D_rope (64). Default: None.
sparse_block_size (int, optional) – Block size of the sparse stage, used when computing importance scores. Must be a power of 2 in the range [1, 128]. Default: 1.
layout_query (str, optional) – Data layout of query. Supports "BSND" and "TND". Default: "BSND".
layout_kv (str, optional) – Data layout of key and value. Supports "TND", "BSND" and "PA_BSND". When "PA_BSND", it need not match layout_query; when "BSND" or "TND", it must match layout_query. Default: "BSND".
sparse_mode (int, optional) – Sparse mode. Only 0 (full computation) and 3 (rightDownCausal mask mode) are supported. Default: 3.
pre_tokens (int, optional) – For sparse computation, the number of tokens counted forward. Default: 2^63-1.
next_tokens (int, optional) – For sparse computation, the number of tokens counted backward. Default: 2^63-1.
attention_mode (int, optional) – Attention mode. Currently only 2 (MLA-absorb mode) is supported, and query_rope and key_rope must be provided in this mode. Default: 0, so the value must be set explicitly.
return_softmax_lse (bool, optional) – Whether to return softmax_max and softmax_sum. When True, valid values are returned, but this is not supported in graph mode; when False, they are not returned. Only supported in training and when layout_kv is not "PA_BSND". Default: False.
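sparse_mode=3 (rightDownCausal) is usually understood as a causal mask whose diagonal is aligned to the bottom-right corner of the S1 x S2 score matrix, so the last query token sees every key. The sketch below assumes that convention and ignores pre_tokens/next_tokens; right_down_causal_mask and allowed_selected are hypothetical helpers, not MindSpore APIs. It also shows how such a mask would filter the positions chosen by sparse_indices for each query row.

```python
import numpy as np


def right_down_causal_mask(s1, s2):
    """True where query row i may attend to key column j (assumed rightDownCausal)."""
    rows = np.arange(s1)[:, None]
    cols = np.arange(s2)[None, :]
    # The diagonal is shifted by s2 - s1 so the last query row sees all s2 keys.
    return cols <= rows + (s2 - s1)


def allowed_selected(mask, sparse_indices):
    """For each query row, flag which selected KV positions survive the mask."""
    rows = np.arange(sparse_indices.shape[0])[:, None]
    return mask[rows, sparse_indices]


mask = right_down_causal_mask(2, 4)
print(mask.astype(int))
# [[1 1 1 0]
#  [1 1 1 1]]
print(right_down_causal_mask(1, 8).all())  # True: a single decode token sees every key
sel = np.array([[0, 3], [1, 3]])
print(allowed_selected(mask, sel))
# [[ True False]
#  [ True  True]]
```

With S1 = 1 (the decode case used in the example below), the mask admits every position, so all selected indices contribute.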
- Returns
Tuple of 3 Tensors.
attention_out (Tensor) - Attention output. Dtype is bfloat16 or float16. When layout_query is "BSND", shape is \((B, S1, N1, D)\). When layout_query is "TND", shape is \((T1, N1, D)\).
softmax_max (Tensor) - Maximum value in the softmax. Dtype is float32. When layout_query is "BSND", shape is \((B, N2, S1, N1/N2)\). When layout_query is "TND", shape is \((N2, T1, N1/N2)\).
softmax_sum (Tensor) - Sum value in the softmax. Dtype is float32. Same shape as softmax_max.
- Raises
TypeError – If query is not a Tensor.
TypeError – If key is not a Tensor.
TypeError – If value is not a Tensor.
TypeError – If sparse_indices is not a Tensor.
TypeError – If the dtypes of query, key, value are inconsistent.
TypeError – If the dtype of query, key, or value is not float16 or bfloat16.
RuntimeError – If layout_query is not "BSND" or "TND".
RuntimeError – If layout_kv is not "BSND", "TND" or "PA_BSND".
RuntimeError – If sparse_block_size is not in range [1, 128] or not a power of 2.
- Supported Platforms:
Ascend
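Since the operator only runs on Ascend, a cheap host-side pre-check can catch configuration mistakes before a call. The sketch below mirrors the constraints listed under Note, Keyword Arguments and Raises; check_sfa_args is a hypothetical helper, and using ValueError for the layout-consistency, sparse_mode and attention_mode checks is an assumption, since this page does not state which exception the operator raises for those cases.

```python
def check_sfa_args(layout_query="BSND", layout_kv="BSND", sparse_block_size=1,
                   sparse_mode=3, attention_mode=0, query_rope=None, key_rope=None):
    """Hypothetical pre-flight check mirroring the documented constraints."""
    if layout_query not in ("BSND", "TND"):
        raise RuntimeError(f"unsupported layout_query: {layout_query!r}")
    if layout_kv not in ("BSND", "TND", "PA_BSND"):
        raise RuntimeError(f"unsupported layout_kv: {layout_kv!r}")
    if layout_kv != "PA_BSND" and layout_kv != layout_query:
        raise ValueError("layout_kv must match layout_query unless it is 'PA_BSND'")
    # A positive power of two has exactly one set bit, so n & (n - 1) == 0.
    if not 1 <= sparse_block_size <= 128 or sparse_block_size & (sparse_block_size - 1):
        raise RuntimeError(f"invalid sparse_block_size: {sparse_block_size}")
    if sparse_mode not in (0, 3):
        raise ValueError(f"sparse_mode must be 0 or 3, got {sparse_mode}")
    if attention_mode != 2:
        raise ValueError("only attention_mode=2 (MLA-absorb) is supported")
    if query_rope is None or key_rope is None:
        raise ValueError("attention_mode=2 requires query_rope and key_rope")


def raises(exc_type, **kwargs):
    """Return True if check_sfa_args(**kwargs) raises exc_type."""
    try:
        check_sfa_args(**kwargs)
    except exc_type:
        return True
    return False


ok = dict(attention_mode=2, query_rope="q_rope", key_rope="k_rope")
check_sfa_args(**ok)  # passes silently
print([raises(RuntimeError, sparse_block_size=n, **ok) for n in (0, 3, 64, 256)])
# [True, True, False, True]
print(raises(ValueError))  # True: the default attention_mode=0 is rejected
print(raises(ValueError, layout_query="TND", **ok))  # True: layout_kv defaults to 'BSND'
```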
Examples
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> b, s1, s2, n1, n2, d, d_rope = 4, 1, 8192, 128, 1, 512, 64
>>> sparse_block_count = 2048
>>> query = Tensor(np.random.randn(b, s1, n1, d).astype(np.float16))
>>> key = Tensor(np.random.randn(b, s2, n2, d).astype(np.float16))
>>> value = key.copy()
>>> query_rope = Tensor(np.random.randn(b, s1, n1, d_rope).astype(np.float16))
>>> key_rope = Tensor(np.random.randn(b, s2, n2, d_rope).astype(np.float16))
>>> sparse_indices = Tensor(np.random.randint(0, s2, (b, s1, n2, sparse_block_count)).astype(np.int32))
>>> scale_value = 1.0 / (d ** 0.5)
>>> attention_out, softmax_max, softmax_sum = ops.sparse_flash_attention(
...     query, key, value, sparse_indices, scale_value,
...     query_rope=query_rope, key_rope=key_rope, attention_mode=2, return_softmax_lse=True)
>>> print(attention_out.shape)
(4, 1, 128, 512)