mindspore.ops.nsa_compress
- mindspore.ops.nsa_compress(input, weight, compress_block_size, compress_stride, *, actual_seq_len) → Tensor
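Compress the sequence dimension of the key/value (KV) tensor with the NSA Compress algorithm, turning each window of compress_block_size tokens into a single output token, to reduce the cost of attention in long-context training.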
Note
Layout is fixed to "TND": T is the total number of tokens across all packed segments, N is the number of heads, and D is the head dimension.
actual_seq_len is interpreted in prefix-sum mode: it must be a non-decreasing integer sequence whose last element equals T. For example, if the per-segment lengths are [s1, s2, s3], then actual_seq_len = (s1, s1 + s2, s1 + s2 + s3), and its last value equals T.
Windows are formed independently within each segment, so no window spans a segment boundary. The compressed outputs of all segments are concatenated in their original order.
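The prefix-sum convention and the per-segment windowing can be illustrated with a minimal pure-Python sketch. The helper names (to_actual_seq_len, segment_lengths, window_starts) are hypothetical and not part of MindSpore, and the sketch assumes that window k of a segment starts at offset k * compress_stride from the segment start, which is consistent with the T' formula under Returns.

```python
from itertools import accumulate


def to_actual_seq_len(segment_lengths):
    """Convert per-segment lengths [s1, s2, ...] into prefix-sum form."""
    return tuple(accumulate(segment_lengths))


def segment_lengths(actual_seq_len):
    """Recover per-segment lengths as differences of consecutive prefix sums."""
    lengths, prev = [], 0
    for end in actual_seq_len:
        lengths.append(end - prev)
        prev = end
    return lengths


def window_starts(actual_seq_len, block, stride):
    """Global start offsets of all windows, built independently per segment.

    Assumption: windows start at seg_begin, seg_begin + stride, ... and a
    window [s, s + block) is kept only if it fits inside its segment.
    """
    starts, seg_begin = [], 0
    for seg_end in actual_seq_len:
        s = seg_begin
        while s + block <= seg_end:  # never cross into the next segment
            starts.append(s)
            s += stride
        seg_begin = seg_end
    return starts


seq = to_actual_seq_len([80, 96, 80])
print(seq)                         # (80, 176, 256)
print(segment_lengths(seq))        # [80, 96, 80]
print(window_starts(seq, 64, 16))  # [0, 16, 80, 96, 112, 176, 192]
```

The seven window starts correspond to the seven output rows in the example below.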
D must be a multiple of 16 and no greater than 256, and 1 <= N <= 128.
compress_block_size must be a multiple of 16 in the range [16, 128].
compress_stride must be a multiple of 16 in the range [16, compress_block_size].
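The constraints above, together with the rules on actual_seq_len, can be checked before calling the operator. The sketch below is a hypothetical helper (check_nsa_compress_args is not a MindSpore API); it only mirrors the documented limits and does not reproduce the exact exception types or messages raised by ops.nsa_compress.

```python
def check_nsa_compress_args(T, N, D, block, stride, actual_seq_len):
    """Hypothetical pre-flight check mirroring the documented constraints.

    Returns a list of violated constraints; an empty list means all pass.
    """
    errors = []
    if D % 16 != 0 or D > 256:
        errors.append("D must be a multiple of 16 and <= 256")
    if not 1 <= N <= 128:
        errors.append("N must satisfy 1 <= N <= 128")
    if block % 16 != 0 or not 16 <= block <= 128:
        errors.append("compress_block_size must be a multiple of 16 in [16, 128]")
    if stride % 16 != 0 or not 16 <= stride <= block:
        errors.append("compress_stride must be a multiple of 16 in [16, compress_block_size]")
    seq = list(actual_seq_len)
    if not seq:
        errors.append("actual_seq_len must not be empty")
    else:
        if any(v <= 0 for v in seq):
            errors.append("actual_seq_len must contain only positive values")
        if any(b < a for a, b in zip(seq, seq[1:])):
            errors.append("actual_seq_len must be non-decreasing")
        if seq[-1] != T:
            errors.append("the last element of actual_seq_len must equal T")
    return errors


print(check_nsa_compress_args(256, 8, 128, 64, 16, (80, 176, 256)))  # []
# D not a multiple of 16, stride > block, and last element != T:
print(len(check_nsa_compress_args(256, 8, 120, 64, 80, (80, 176, 250))))  # 3
```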
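- Parameters
input (Tensor) – The packed KV sequence in "TND" layout, with shape (T, N, D). Supported dtypes: float16 and bfloat16.
weight (Tensor) – Compression weights of shape (compress_block_size, N), with the same dtype as input.
compress_block_size (int) – Number of tokens in each compression window.
compress_stride (int) – Distance, in tokens, between the starts of adjacent windows within a segment.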
- Keyword Arguments
actual_seq_len (Union[tuple[int], list[int]]) – Cumulative (prefix-sum) end offsets of the packed segments along T. The sequence must be non-decreasing, contain only positive values, and its last element must equal T.
- Returns
Tensor of shape (T', N, D) with the same dtype as input. The first dimension \(T'\) is determined jointly by actual_seq_len, compress_block_size and compress_stride. Let \(L_i\) be the length of the i-th segment, i.e. the difference between consecutive elements of actual_seq_len, with \(L_1\) equal to its first element. Then
\(T' = \sum_i \max\big(0,\; 1 + \big\lfloor \frac{L_i - \mathrm{compress\_block\_size}}{\mathrm{compress\_stride}} \big\rfloor\big)\),
so a segment shorter than compress_block_size contributes no output rows.
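The T' formula can be evaluated directly with a few lines of Python. compressed_length is a hypothetical helper name used only for illustration. With the segment lengths from the example ([80, 96, 80], compress_block_size 64, compress_stride 16), the segments contribute 2 + 3 + 2 = 7 rows, matching the printed shape (7, 8, 128).

```python
def compressed_length(actual_seq_len, block, stride):
    """T' = sum_i max(0, 1 + floor((L_i - block) / stride))."""
    total, prev = 0, 0
    for end in actual_seq_len:
        L = end - prev  # per-segment length L_i
        prev = end
        # Python's // floors toward negative infinity, matching the floor
        # in the formula even when L_i < block.
        total += max(0, 1 + (L - block) // stride)
    return total


print(compressed_length((80, 176, 256), 64, 16))  # 7
print(compressed_length((48, 176), 64, 16))       # 5 (the 48-token segment adds 0)
```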
- Raises
TypeError – If input is not a Tensor.
TypeError – If weight is not a Tensor.
TypeError – If the dtypes of input and weight are inconsistent.
TypeError – If the dtype of input is not float16 or bfloat16.
TypeError – If compress_block_size is not an int.
TypeError – If compress_stride is not an int.
TypeError – If actual_seq_len is not a tuple/list of ints.
RuntimeError – If the rank of input is not 3.
RuntimeError – If the rank of weight is not 2.
RuntimeError – If weight.shape[0] != compress_block_size.
RuntimeError – If weight.shape[1] != N (where N is the second dimension of input).
RuntimeError – If D % 16 != 0.
RuntimeError – If D > 256.
RuntimeError – If N < 1.
RuntimeError – If N > 128.
RuntimeError – If compress_block_size is not a multiple of 16.
RuntimeError – If compress_block_size is not in [16, 128].
RuntimeError – If compress_stride is not a multiple of 16.
RuntimeError – If compress_stride is not in [16, compress_block_size].
RuntimeError – If actual_seq_len is empty.
RuntimeError – If actual_seq_len is not non-decreasing.
RuntimeError – If actual_seq_len contains non-positive values.
RuntimeError – If the last element of actual_seq_len does not equal T.
- Supported Platforms:
Ascend
Examples
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> N, D, block, stride = 8, 128, 64, 16
>>> per_segments = [80, 96, 80]  # three packed segments
>>> actual_seq = tuple(np.cumsum(per_segments, dtype=np.int64).tolist())  # (80, 176, 256)
>>> T = int(actual_seq[-1])
>>> x = Tensor(np.random.randn(T, N, D).astype(np.float16))
>>> w = Tensor(np.random.randn(block, N).astype(np.float16))
>>> y = ops.nsa_compress(x, w, block, stride, actual_seq_len=actual_seq)
>>> print(y.shape)
(7, 8, 128)
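This page does not spell out the arithmetic performed inside each window. The NumPy sketch below ASSUMES that each output row is a per-head weighted sum of the window's tokens, using weight[j, n] for window position j and head n. That assumption is only suggested by the (compress_block_size, N) weight shape and has not been verified against the Ascend kernel, so treat nsa_compress_reference (a hypothetical name) as a way to build intuition about shapes and windowing, not as a numerical reference. Accumulating in float32 is also an assumption.

```python
import numpy as np


def nsa_compress_reference(x, w, block, stride, actual_seq_len):
    """Hypothetical NumPy sketch, ASSUMING a per-head weighted sum per window:

        y[k, n, :] = sum_j w[j, n] * x[start_k + j, n, :]

    x: (T, N, D), w: (block, N). Accumulates in float32 (an assumption).
    """
    x32, w32 = x.astype(np.float32), w.astype(np.float32)
    outputs, seg_begin = [], 0
    for seg_end in actual_seq_len:
        s = seg_begin
        while s + block <= seg_end:  # windows never cross a segment boundary
            window = x32[s:s + block]                              # (block, N, D)
            outputs.append(np.einsum("jn,jnd->nd", w32, window))  # (N, D)
            s += stride
        seg_begin = seg_end
    N, D = x.shape[1], x.shape[2]
    if not outputs:
        return np.zeros((0, N, D), dtype=x.dtype)
    return np.stack(outputs).astype(x.dtype)


N, D, block, stride = 8, 128, 64, 16
actual_seq = (80, 176, 256)
x = np.random.randn(actual_seq[-1], N, D).astype(np.float16)
w = np.random.randn(block, N).astype(np.float16)
print(nsa_compress_reference(x, w, block, stride, actual_seq).shape)  # (7, 8, 128)
```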