mindspore.ops.nsa_compress_attention
- mindspore.ops.nsa_compress_attention(query, key, value, scale_value, head_num, compress_block_size, compress_stride, select_block_size, select_block_count, *, topk_mask=None, atten_mask=None, actual_seq_qlen, actual_cmp_seq_kvlen, actual_sel_seq_kvlen)
Computes compressed attention with the NSA (Native Sparse Attention) compression algorithm and selects the most important KV blocks for each query (Ascend).
The operator computes compressed attention through the following steps:
Compute attention scores:
\[QK = scale \cdot query \cdot key^T\]
Apply the attention mask (if provided):
\[QK = \operatorname{atten\_mask}(QK, atten\_mask)\]
Compute the compressed attention probabilities:
\[P_{cmp} = \text{Softmax}(QK)\]
Compute the attention output:
\[\text{attentionOut} = P_{cmp} \cdot value\]
Compute the importance score of each select block:
\[P_{slc}[j] = \sum_{m=0}^{l'/d-1} \sum_{n=0}^{l/d-1} P_{cmp}\left[\frac{l'}{d} \cdot j - m - n\right]\]
where \(l'\) is the select block size (select_block_size), \(l\) is the compress block size (compress_block_size), and \(d\) is compress_stride.
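The first five steps can be mirrored in plain NumPy for a single head of a single sequence, which may help when checking shapes or debugging results against the operator. This is a minimal reference sketch, not the Ascend kernel: `compress_attention_single_head` and `importance_scores` are hypothetical helper names, the mask convention (True keeps a score) is an assumption because it is not stated here, and indices of \(P_{cmp}\) that fall outside the valid range are assumed to contribute nothing.

```python
import numpy as np


def compress_attention_single_head(q, k, v, scale, atten_mask=None):
    """Reference sketch of the score, mask, softmax and output steps.

    q: (S1, D1), k: (S2, D1), v: (S2, D2); atten_mask: bool (S1, S2) or None.
    ASSUMPTION: True in atten_mask keeps a score, False masks it out.
    Every query row is assumed to keep at least one position.
    """
    qk = scale * (q @ k.T)                      # QK = scale * query * key^T
    if atten_mask is not None:
        qk = np.where(atten_mask, qk, -np.inf)  # masked scores get zero probability
    qk = qk - qk.max(axis=-1, keepdims=True)    # subtract the row max for stability
    e = np.exp(qk)
    p_cmp = e / e.sum(axis=-1, keepdims=True)   # P_cmp = Softmax(QK)
    out = p_cmp @ v                             # attentionOut = P_cmp * value
    return out, p_cmp


def importance_scores(p_cmp_row, compress_block_size, compress_stride,
                      select_block_size, num_select_blocks):
    """Literal transcription of the P_slc[j] double sum for one query row.

    l = compress_block_size, l' = select_block_size, d = compress_stride.
    ASSUMPTION: indices outside [0, len(p_cmp_row)) contribute nothing.
    """
    l, l_sel, d = compress_block_size, select_block_size, compress_stride
    p_slc = np.zeros(num_select_blocks)
    for j in range(num_select_blocks):
        for m in range(l_sel // d):
            for n in range(l // d):
                idx = (l_sel // d) * j - m - n
                if 0 <= idx < len(p_cmp_row):
                    p_slc[j] += p_cmp_row[idx]
    return p_slc


# With l = 32, d = 16, l' = 64 and a uniform row of 8 compressed scores,
# block 0 only reaches index 0 and block 2 loses its out-of-range index 8,
# so the block scores are 1, 8 and 7.
print(importance_scores(np.ones(8), 32, 16, 64, 3))
```

The triple loop is deliberately literal so that each index can be compared with the formula term by term; a production implementation would vectorize it.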
Aggregate the importance scores across heads:
\[P_{slc}' = \sum_{h=1}^{H} P_{slc}^h\]
where \(H\) is the number of heads.
Apply the TopK mask:
\[P_{slc}' = \operatorname{topk\_mask}(P_{slc}')\]
Select the TopK indices:
\[\text{topkIndices} = \text{topk}(P_{slc}')\]
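The aggregation and selection steps can be sketched in NumPy as well. Because topk_indices_out has shape (T1, N2, select_block_count), this sketch assumes the sum over heads runs within each group of N1/N2 query heads that share one key head, with query heads grouped contiguously; with N2 = 1 it reduces to the plain sum over all H heads in the formula. It also assumes that False in the TopK mask excludes a block from selection. `select_topk_blocks` is a hypothetical name, not a MindSpore API.

```python
import numpy as np


def select_topk_blocks(p_slc, n_kv_heads, select_block_count, topk_mask_row=None):
    """Pick the top select blocks for one query token.

    p_slc: (N1, S_sel) importance scores per query head.
    Returns int32 indices of shape (N2, select_block_count).
    ASSUMPTIONS: query head h uses KV head h // (N1 // N2), and
    False in topk_mask_row (shape (S_sel,)) excludes that block.
    """
    n_q_heads, num_blocks = p_slc.shape
    group = n_q_heads // n_kv_heads
    # P'_slc: sum the scores of the query heads that share each KV head
    p_agg = p_slc.reshape(n_kv_heads, group, num_blocks).sum(axis=1)
    if topk_mask_row is not None:
        # Excluded blocks get -inf so they sort after every real score
        p_agg = np.where(topk_mask_row, p_agg, -np.inf)
    # Largest scores first; a stable sort keeps the lower index on ties
    order = np.argsort(-p_agg, axis=-1, kind="stable")
    return order[:, :select_block_count].astype(np.int32)


scores = np.array([[0.1, 0.5, 0.2, 0.9],
                   [0.3, 0.1, 0.6, 0.0],
                   [0.2, 0.2, 0.2, 0.2],
                   [0.0, 0.9, 0.1, 0.1]])
# 4 query heads share 2 KV heads; keep the 2 best blocks per KV head.
print(select_topk_blocks(scores, n_kv_heads=2, select_block_count=2))
```

Tie-breaking order is a choice made for this sketch; the operator's own tie handling is not documented here.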
Note
The internal layout is fixed to "TND" on the Ascend platform.
actual_seq_qlen, actual_cmp_seq_kvlen, and actual_sel_seq_kvlen are given in prefix-sum (cumulative) form.
compress_block_size must be a multiple of 16, supported range: 16 to 128.
compress_stride must be a multiple of 16, supported range: 16 to 64.
select_block_size must be a multiple of 16, supported range: 16 to 128.
select_block_count supported range: 1 to 32.
compress_block_size >= compress_stride
select_block_size >= compress_block_size
select_block_size % compress_stride == 0
Head dimension constraints: query and key must share the head dimension D1, and D1 >= D2 (the key head dimension must be greater than or equal to the value head dimension).
Head count constraints: N1 >= N2 (the query head count must be greater than or equal to the key head count) and N1 % N2 == 0 (the query head count must be a multiple of the key head count).
D1 and D2 must be multiples of 16.
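The constraints above can be checked on the host before calling the operator. The sketch below is a hypothetical helper (`check_nsa_constraints` is not part of MindSpore) that collects every violation it finds instead of stopping at the first one; it encodes only the rules listed in this note plus the head_num == N1 requirement from the parameter list.

```python
def check_nsa_constraints(compress_block_size, compress_stride, select_block_size,
                          select_block_count, n1, n2, d1, d2, head_num):
    """Return human-readable violations of the documented constraints (empty if none).

    Hypothetical helper: n1/n2 are the query/key head counts and d1/d2 the
    query-key/value head dimensions.
    """
    errors = []

    def check_block(name, value, lo, hi):
        # Block-size style parameters: a multiple of 16 inside [lo, hi]
        if value % 16 != 0 or not lo <= value <= hi:
            errors.append(f"{name}={value} must be a multiple of 16 in [{lo}, {hi}]")

    check_block("compress_block_size", compress_block_size, 16, 128)
    check_block("compress_stride", compress_stride, 16, 64)
    check_block("select_block_size", select_block_size, 16, 128)
    if not 1 <= select_block_count <= 32:
        errors.append(f"select_block_count={select_block_count} must be in [1, 32]")
    if compress_block_size < compress_stride:
        errors.append("compress_block_size must be >= compress_stride")
    if select_block_size < compress_block_size:
        errors.append("select_block_size must be >= compress_block_size")
    # Guard against a zero stride before taking the modulo
    if compress_stride and select_block_size % compress_stride != 0:
        errors.append("select_block_size must be divisible by compress_stride")
    if d1 < d2 or d1 % 16 != 0 or d2 % 16 != 0:
        errors.append("need D1 >= D2, with D1 and D2 multiples of 16")
    # n2 <= 0 short-circuits before the modulo
    if n2 <= 0 or n1 < n2 or n1 % n2 != 0:
        errors.append("need N1 >= N2 and N1 % N2 == 0")
    if head_num != n1:
        errors.append("head_num must equal N1")
    return errors


# Values from the example at the end of this page.
print(check_nsa_constraints(32, 16, 64, 16, n1=8, n2=8, d1=128, d2=128, head_num=8))  # prints []
```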
- Parameters
query (Tensor) – Query tensor with shape (T1, N1, D1), where T1 is query sequence length, N1 is query head count, D1 is head dimension. Dtype float16 or bfloat16. Required.
key (Tensor) – Key tensor with shape (T2, N2, D1), where T2 is key sequence length, N2 is key head count, D1 is head dimension (same as query). Dtype float16 or bfloat16. Required.
value (Tensor) – Value tensor with shape (T2, N2, D2), where T2 is value sequence length, N2 is value head count (same as key), D2 is value head dimension. Dtype float16 or bfloat16. Required.
scale_value (float) – Scale factor for attention scores. Required.
head_num (int) – Number of attention heads; must equal the query head count N1. Required.
compress_block_size (int) – Size of the sliding window used for compression. Required.
compress_stride (int) – Distance between adjacent sliding windows. Required.
select_block_size (int) – Size of each select block. Required.
select_block_count (int) – Number of select blocks to keep per query (the k in TopK). Required.
- Keyword Arguments
topk_mask (Tensor, optional) – TopK mask tensor with shape (S1, S2), where S1 is query sequence length, S2 is select KV sequence length. Dtype bool. Default: None.
atten_mask (Tensor, optional) – Attention mask tensor with shape (S1, S2), where S1 is query sequence length, S2 is compress KV sequence length. Dtype bool. Default: None.
actual_seq_qlen (Union[tuple[int], list[int]]) – Per-batch query sequence lengths, in prefix-sum form. Required.
actual_cmp_seq_kvlen (Union[tuple[int], list[int]]) – Per-batch compress KV sequence lengths, in prefix-sum form. Required.
actual_sel_seq_kvlen (Union[tuple[int], list[int]]) – Per-batch select KV sequence lengths, in prefix-sum form. Required.
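In prefix-sum mode each entry is the cumulative token count up to and including that batch item, so the last entry normally equals the total number of tokens packed along the T axis. For example, three query sequences of lengths 128, 64 and 64 would be passed as [128, 192, 256]. The helpers below are hypothetical plain-Python conversions, not MindSpore APIs.

```python
from itertools import accumulate


def to_prefix_sum(lengths):
    # [l0, l1, l2] -> [l0, l0 + l1, l0 + l1 + l2]
    return list(accumulate(lengths))


def from_prefix_sum(prefix):
    # Inverse: recover per-batch lengths from the cumulative offsets
    return [b - a for a, b in zip([0] + list(prefix[:-1]), prefix)]


actual_seq_qlen = to_prefix_sum([128, 64, 64])
print(actual_seq_qlen)                    # prints [128, 192, 256]
print(from_prefix_sum(actual_seq_qlen))   # prints [128, 64, 64]
```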
- Returns
tuple, containing four tensors.
attention_out (Tensor) - Attention output tensor, shape (T1, N1, D2), dtype same as query.
topk_indices_out (Tensor) - TopK indices tensor, shape (T1, N2, select_block_count), where T1 is query sequence length, N2 is key head count, dtype int32.
softmax_max_out (Tensor) - Softmax max intermediate result, shape (T1, N1, 8), dtype float32.
softmax_sum_out (Tensor) - Softmax sum intermediate result, shape (T1, N1, 8), dtype float32.
- Raises
TypeError – If input parameter types are incorrect.
ValueError – If input tensor shapes or parameters don't satisfy constraints.
- Supported Platforms:
Ascend
Examples
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> T, N, D = 256, 8, 128
>>> query = Tensor(np.random.randn(T, N, D).astype(np.float16))
>>> key = Tensor(np.random.randn(T, N, D).astype(np.float16))
>>> value = Tensor(np.random.randn(T, N, D).astype(np.float16))
>>> topk_mask = Tensor(np.ones((T, T), dtype=np.bool_))
>>> atten_mask = Tensor(np.ones((T, T), dtype=np.bool_))
>>> actual_seq_qlen = [T]
>>> actual_cmp_seq_kvlen = [T]
>>> actual_sel_seq_kvlen = [T]
>>> compress_block_size = 32
>>> compress_stride = 16
>>> select_block_size = 64
>>> select_block_count = 16
>>> scale_value = 1.0 / (D ** 0.5)
>>> head_num = N
>>> # Arguments after select_block_count are keyword-only
>>> attention_out, topk_indices_out, softmax_max_out, softmax_sum_out = ops.nsa_compress_attention(
...     query, key, value, scale_value, head_num, compress_block_size, compress_stride,
...     select_block_size, select_block_count, topk_mask=topk_mask, atten_mask=atten_mask,
...     actual_seq_qlen=actual_seq_qlen, actual_cmp_seq_kvlen=actual_cmp_seq_kvlen,
...     actual_sel_seq_kvlen=actual_sel_seq_kvlen)
>>> print(attention_out.shape)
(256, 8, 128)