mindspore.ops.nsa_compress

mindspore.ops.nsa_compress(input, weight, compress_block_size, compress_stride, *, actual_seq_len) → Tensor

Compress the KV sequence dimension using the NSA Compress algorithm to reduce attention computation in long-context training.

Note

  • Layout is fixed to "TND".

  • actual_seq_len is interpreted in prefix-sum mode: it must be a non-decreasing integer sequence whose last element equals T. For per-segment lengths [s1, s2, s3], actual_seq_len = (s1, s1 + s2, s1 + s2 + s3).

  • Windows are formed independently inside each segment; there is no cross-segment window. Compressed outputs from all segments are concatenated in the original order.

  • D must be a multiple of 16 and no greater than 256; N must satisfy 1 <= N <= 128.

  • compress_block_size must be a multiple of 16 and no greater than 128.

  • compress_stride must be a multiple of 16 and 16 <= compress_stride <= compress_block_size.
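The notes above pin down how windows are formed but not the reduction applied inside each window. As a rough mental model (an assumption for illustration, not a statement of the kernel's exact semantics), each window of compress_block_size tokens can be reduced to a single token by a per-head weighted sum using weight. A NumPy sketch under that assumption:

```python
import numpy as np

def nsa_compress_ref(x, w, block, stride, actual_seq_len):
    """Reference sketch of the compression, ASSUMING each window of
    `block` tokens collapses to one token via a per-head weighted sum
    with w[:, n]. x: (T, N, D); w: (block, N)."""
    outs = []
    start = 0
    for end in actual_seq_len:            # prefix sums delimit the segments
        seg = x[start:end]                # (L_i, N, D); windows never cross segments
        n_win = max(0, 1 + (seg.shape[0] - block) // stride)
        for i in range(n_win):
            win = seg[i * stride : i * stride + block]    # (block, N, D)
            outs.append(np.einsum('kn,knd->nd', w, win))  # weighted sum over window axis
        start = end
    if not outs:                          # every segment shorter than one window
        return np.empty((0,) + x.shape[1:], dtype=x.dtype)
    return np.stack(outs)                 # (T', N, D)
```

The sketch makes the segment-local windowing concrete: each segment contributes max(0, 1 + floor((L_i - block) / stride)) compressed tokens, concatenated in the original segment order.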

Parameters
  • input (Tensor) – Shape (T, N, D), dtype float16 or bfloat16.

  • weight (Tensor) – Shape (compress_block_size, N), same dtype as input.

  • compress_block_size (int) – Sliding window size for compression.

  • compress_stride (int) – Step between adjacent windows.

Keyword Arguments

actual_seq_len (Union[tuple[int], list[int]]) – Per-segment cumulative sequence lengths in prefix-sum mode. The sequence must be non-decreasing and its last element must equal T.

Returns

Tensor. Shape is (T', N, D) with the same dtype as input. The first dimension \(T'\) is determined jointly by actual_seq_len, compress_block_size, and compress_stride. Let the per-segment lengths \(L_i\) be the differences of consecutive actual_seq_len entries. Then \(T'\) is given by \(T' = \sum_i \max\big(0,\; 1 + \big\lfloor \frac{L_i - \mathrm{compress\_block\_size}} {\mathrm{compress\_stride}} \big\rfloor\big)\).
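The formula above can be checked with a small helper that recovers per-segment lengths from the prefix sums and sums the window counts (compressed_length is an illustrative name, not part of the API):

```python
import numpy as np

def compressed_length(actual_seq_len, block, stride):
    """Compute T' = sum_i max(0, 1 + floor((L_i - block) / stride)),
    where L_i are the differences of consecutive prefix sums."""
    lengths = np.diff(np.concatenate(([0], actual_seq_len)))
    # Python's // floors toward -inf, matching the floor in the formula
    # even when L_i < block (the max(0, ...) then clamps to zero windows).
    return int(sum(max(0, 1 + (L - block) // stride) for L in lengths))
```

For the Examples section below, compressed_length((80, 176, 256), 64, 16) evaluates to 2 + 3 + 2 = 7, matching the printed output shape (7, 8, 128).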

Supported Platforms:

Ascend

Examples

>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> N, D, block, stride = 8, 128, 64, 16
>>> per_segments = [80, 96, 80]
>>> actual_seq = tuple(np.cumsum(per_segments, dtype=np.int64).tolist())
>>> T = int(actual_seq[-1])
>>> x = Tensor(np.random.randn(T, N, D).astype(np.float16))
>>> w = Tensor(np.random.randn(block, N).astype(np.float16))
>>> y = ops.nsa_compress(x, w, block, stride, actual_seq_len=actual_seq)
>>> print(y.shape)
(7, 8, 128)