mindspore.ops.ring_attention_update

mindspore.ops.ring_attention_update(prev_attn_out, prev_softmax_max, prev_softmax_sum, cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=None, layout='SBH')[source]

The RingAttentionUpdate operator combines the outputs of two FlashAttention operations into one, based on their respective softmax max and softmax sum values. The notation below is used in the shape descriptions, and an illustrative sketch of the assumed combine rule follows the list.

  • S: Sequence length

  • B: Batch dimension

  • H: Hidden layer size, equal to N * D

  • T: Total token count across the batch, equal to B * S

  • N: Number of attention heads

  • D: Head dimension
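
Conceptually, the update follows the online-softmax combine rule used in ring attention. The NumPy sketch below only illustrates that assumed rule on simplified \((T, N, D)\) outputs with \((T, N, 1)\) statistics; it is not the operator's implementation, which additionally handles the "SBH" layout and the 8-way replicated statistics.

>>> import numpy as np
>>> def combine_flash_outputs(prev_out, prev_max, prev_sum, cur_out, cur_max, cur_sum):
...     # Merge two partial FlashAttention results: take the element-wise max,
...     # rescale each partial sum and output by exp(old_max - new_max),
...     # then renormalize by the merged sum.
...     new_max = np.maximum(prev_max, cur_max)
...     prev_sum_scaled = prev_sum * np.exp(prev_max - new_max)
...     cur_sum_scaled = cur_sum * np.exp(cur_max - new_max)
...     new_sum = prev_sum_scaled + cur_sum_scaled
...     out = (prev_out * prev_sum_scaled + cur_out * cur_sum_scaled) / new_sum
...     return out, new_max, new_sum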

Warning

  • It is only supported on Atlas A2 Training Series Products.

  • This is an experimental API that is subject to change or deletion.

  • When layout is "TND", the last dimension of prev_attn_out must be a multiple of 64.

  • When layout is "TND", actual_seq_qlen is mandatory.

  • When layout is "TND", N * D must satisfy the constraint: \((\text{AlignUp}(N*D, 64)*(DataSize*6+8))+(\text{AlignUp}(N*8, 64)*56) <= 192*1024\). \(DataSize\) is 4 bytes when prev_attn_out dtype is float32, 2 bytes when dtype is float16 / bfloat16.

  • When layout is "TND", if actual_seq_qlen is not a non-decreasing sequence from 0 to T, the result is undefined.

Parameters
  • prev_attn_out (Tensor) – Output of the first FlashAttention operation. The dtype is float16, float32 or bfloat16. The shape is \((S, B, H)\) or \((T, N, D)\).

  • prev_softmax_max (Tensor) – The max values from the first FlashAttention softmax computation. The dtype is float32. The shape is \((B, N, S, 8)\) or \((T, N, 8)\). The last dimension contains 8 identical values, which must be positive.

  • prev_softmax_sum (Tensor) – The sum values from the first FlashAttention softmax computation. It has the same shape and dtype as prev_softmax_max.

  • cur_attn_out (Tensor) – Output of the second FlashAttention operation. It has the same shape and dtype as prev_attn_out.

  • cur_softmax_max (Tensor) – The max values from the second FlashAttention softmax computation. It has the same shape and dtype as prev_softmax_max.

  • cur_softmax_sum (Tensor) – The sum values from the second FlashAttention softmax computation. It has the same shape and dtype as prev_softmax_max.

  • actual_seq_qlen (Tensor, optional) – Cumulative sequence lengths, starting from 0. Required if layout is "TND"; does not take effect if layout is "SBH". The tensor must be 1D and contain non-decreasing integer values from 0 up to T (a construction sketch follows this list). Default: None.

  • layout (str, optional) – Indicates the input layout. Currently supported values are "TND" and "SBH". Default: "SBH".
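
For illustration, a plausible way to build actual_seq_qlen for three packed samples of 3, 5 and 4 tokens (so T = 12) is the cumulative prefix below. The leading 0 matches the description above, while the int64 dtype is an assumption, not a requirement taken from the operator's source:

>>> import numpy as np
>>> from mindspore import Tensor
>>> seq_lens = [3, 5, 4]
>>> cu_seqlens = np.cumsum([0] + seq_lens).astype(np.int64)  # non-decreasing, from 0 to T
>>> print(cu_seqlens)
[ 0  3  8 12]
>>> actual_seq_qlen = Tensor(cu_seqlens)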

Returns

tuple (Tensor), tuple of 3 tensors.

  • attn_out (Tensor) - The updated attention output, with the same shape and dtype as prev_attn_out.

  • softmax_max (Tensor) - The updated softmax max values, with the same shape and dtype as prev_softmax_max.

  • softmax_sum (Tensor) - The updated softmax sum values, with the same shape and dtype as prev_softmax_max.

Raises
  • RuntimeError – If layout is "TND", and prev_attn_out's last dimension is not aligned to 64.

  • RuntimeError – If layout is "TND", and actual_seq_qlen is not provided.

  • RuntimeError – If layout is "TND", and actual_seq_qlen is not a non-decreasing sequence from 0 to T.

  • RuntimeError – If layout is "TND", and prev_attn_out exceeds the size constraints.

Supported Platforms:

Ascend

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> np.random.seed(123)
>>> S, B, H, N = 4, 6, 16, 8
>>> prev_attn_out = np.random.uniform(-1.0, 1.0, size=(S, B, H)).astype(np.float32)
>>> # softmax statistics are positive and replicated 8 times along the last axis
>>> prev_softmax_max = np.tile(np.random.uniform(0.5, 1.0, size=(B, N, S, 1)), (1, 1, 1, 8)).astype(np.float32)
>>> prev_softmax_sum = np.tile(np.random.uniform(0.5, 1.0, size=(B, N, S, 1)), (1, 1, 1, 8)).astype(np.float32)
>>> cur_attn_out = np.random.uniform(-1.0, 1.0, size=(S, B, H)).astype(np.float32)
>>> cur_softmax_max = np.tile(np.random.uniform(0.5, 1.0, size=(B, N, S, 1)), (1, 1, 1, 8)).astype(np.float32)
>>> cur_softmax_sum = np.tile(np.random.uniform(0.5, 1.0, size=(B, N, S, 1)), (1, 1, 1, 8)).astype(np.float32)
>>> inputs_np = [prev_attn_out, prev_softmax_max, prev_softmax_sum, cur_attn_out, cur_softmax_max, cur_softmax_sum]
>>> inputs_ms = [Tensor(item) for item in inputs_np]
>>> out = ops.ring_attention_update(*inputs_ms)
>>> print(out[0].shape)
(4, 6, 16)
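
The following is a shape-level sketch of the "TND" layout. The sizes are chosen to satisfy the constraints listed in the warnings (D is a multiple of 64, the statistics are positive and replicated 8 times along the last axis, and actual_seq_qlen is provided); the specific values are illustrative only.

>>> T, N, D = 12, 2, 64
>>> def rand_stat():
...     # positive statistics replicated 8 times along the last axis
...     return np.tile(np.random.uniform(0.5, 1.0, size=(T, N, 1)), (1, 1, 8)).astype(np.float32)
>>> prev_attn_out = Tensor(np.random.uniform(-1.0, 1.0, size=(T, N, D)).astype(np.float32))
>>> cur_attn_out = Tensor(np.random.uniform(-1.0, 1.0, size=(T, N, D)).astype(np.float32))
>>> prev_softmax_max, prev_softmax_sum = Tensor(rand_stat()), Tensor(rand_stat())
>>> cur_softmax_max, cur_softmax_sum = Tensor(rand_stat()), Tensor(rand_stat())
>>> actual_seq_qlen = Tensor(np.array([0, 3, 8, 12], dtype=np.int64))  # three packed samples
>>> out = ops.ring_attention_update(prev_attn_out, prev_softmax_max, prev_softmax_sum,
...                                 cur_attn_out, cur_softmax_max, cur_softmax_sum,
...                                 actual_seq_qlen, layout="TND")
>>> print(out[0].shape)
(12, 2, 64)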