mindspore.ops.ring_attention_update
- mindspore.ops.ring_attention_update(prev_attn_out, prev_softmax_max, prev_softmax_sum, cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=None, layout='SBH')
The RingAttentionUpdate operator combines the outputs of two FlashAttention operations based on their respective softmax max and softmax sum values (a sketch of the update rule follows the notation list below).
- S: Sequence length
- B: Batch dimension
- H: Hidden layer size, equal to N * D
- T: Time, equal to B * S
- N: Number of attention heads
- D: Head dimension
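The page does not spell out the merge formula itself; as a hedged sketch, an update of this kind typically follows the standard online-softmax combination of two partial attention results, writing \(m\) for softmax max, \(l\) for softmax sum and \(o\) for attention output:

\[m = \max(m_{prev}, m_{cur}), \qquad l = l_{prev}\,e^{m_{prev}-m} + l_{cur}\,e^{m_{cur}-m},\]

\[o = \frac{l_{prev}\,e^{m_{prev}-m}\,o_{prev} + l_{cur}\,e^{m_{cur}-m}\,o_{cur}}{l}.\]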
Warning
It is only supported on Atlas A2 Training Series Products.
This is an experimental API that is subject to change or deletion.
- When layout is "TND", the last dimension of prev_attn_out must be a multiple of 64.
- When layout is "TND", actual_seq_qlen is mandatory.
- When layout is "TND", N * D must satisfy the constraint \((\text{AlignUp}(N*D, 64)*(DataSize*6+8))+(\text{AlignUp}(N*8, 64)*56) \le 192*1024\), where \(DataSize\) is 4 bytes when the prev_attn_out dtype is float32 and 2 bytes when the dtype is float16 / bfloat16 (a constraint-check sketch follows this list).
- When layout is "TND", if actual_seq_qlen is not a non-decreasing sequence from 0 to T, the result is undefined.
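The size constraint above can be checked ahead of time with plain Python. The helper below is a hedged sketch; align_up and tnd_size_ok are illustrative names, not part of the MindSpore API.

>>> def align_up(x, align):
...     return (x + align - 1) // align * align
>>> def tnd_size_ok(n, d, dtype_bytes):
...     # dtype_bytes: 4 for float32, 2 for float16 / bfloat16
...     used = align_up(n * d, 64) * (dtype_bytes * 6 + 8) + align_up(n * 8, 64) * 56
...     return used <= 192 * 1024
>>> tnd_size_ok(8, 128, 2)
True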
- Parameters
prev_attn_out (Tensor) – Output of the first FlashAttention operation. The dtype is float16, float32 or bfloat16. The shape is \((S, B, H)\) or \((T, N, D)\).
prev_softmax_max (Tensor) – The max values from the first FlashAttention softmax computation. The dtype is float32. The shape is \((B, N, S, 8)\) or \((T, N, 8)\). The last dimension contains 8 identical values, which must be positive.
prev_softmax_sum (Tensor) – The sum values from the first FlashAttention softmax computation. It has the same shape and dtype as prev_softmax_max.
cur_attn_out (Tensor) – Output of the second FlashAttention operation. It has the same shape and dtype as prev_attn_out.
cur_softmax_max (Tensor) – The max values from the second FlashAttention softmax computation. It has the same shape and dtype as prev_softmax_max.
cur_softmax_sum (Tensor) – The sum values from the second FlashAttention softmax computation. It has the same shape and dtype as prev_softmax_max.
actual_seq_qlen (Tensor, optional) – Cumulative sequence length, starting from 0. Required if layout is "TND"; does not take effect if layout is "SBH". The tensor must be 1D and contain non-decreasing integer values starting from 0 and ending at T (a construction sketch follows this parameter list). Default: None.
layout (str, optional) – Indicates the input layout. Currently supports "TND" and "SBH". Default: "SBH".
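For the "TND" layout, actual_seq_qlen can be built as the cumulative sum of the per-sequence token counts, starting at 0 and ending at T, as described above. A minimal sketch (the int64 dtype is an assumption, not stated on this page):

>>> import numpy as np
>>> from mindspore import Tensor
>>> seq_lens = [3, 5, 8]                                  # per-sequence token counts, T = 16
>>> cum = np.concatenate(([0], np.cumsum(seq_lens)))      # [0, 3, 8, 16]
>>> actual_seq_qlen = Tensor(cum.astype(np.int64))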
- Returns
tuple (Tensor), a tuple of 3 tensors.
attn_out (Tensor) - The updated attention output, with the same shape and dtype as prev_attn_out.
softmax_max (Tensor) - The updated softmax max values, with the same shape and dtype as prev_softmax_max.
softmax_sum (Tensor) - The updated softmax sum values, with the same shape and dtype as prev_softmax_max.
- Raises
RuntimeError – If layout is "TND" and prev_attn_out's last dimension is not aligned to 64.
RuntimeError – If layout is "TND" and actual_seq_qlen is not provided.
RuntimeError – If layout is "TND" and actual_seq_qlen is not a non-decreasing sequence from 0 to T.
RuntimeError – If layout is "TND" and prev_attn_out exceeds the size constraints.
- Supported Platforms:
Ascend
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> np.random.seed(123)
>>> S, B, H, N = 4, 6, 16, 8
>>> prev_attn_out = np.random.uniform(-1.0, 1.0, size=(S, B, H)).astype(np.float32)
>>> prev_softmax_max = np.random.uniform(-1.0, 1.0, size=(B, N, S, 8)).astype(np.float32)
>>> prev_softmax_sum = np.random.uniform(-1.0, 1.0, size=(B, N, S, 8)).astype(np.float32)
>>> cur_attn_out = np.random.uniform(-1.0, 1.0, size=(S, B, H)).astype(np.float32)
>>> cur_softmax_max = np.random.uniform(-1.0, 1.0, size=(B, N, S, 8)).astype(np.float32)
>>> cur_softmax_sum = np.random.uniform(-1.0, 1.0, size=(B, N, S, 8)).astype(np.float32)
>>> inputs_np = [prev_attn_out, prev_softmax_max, prev_softmax_sum, cur_attn_out, cur_softmax_max, cur_softmax_sum]
>>> inputs_ms = [Tensor(item) for item in inputs_np]
>>> out = ops.ring_attention_update(*inputs_ms)
>>> print(out[0].shape)
(4, 6, 16)
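A "TND"-layout variant might look like the sketch below (illustrative only; the shapes follow the parameter descriptions above, the positive, lane-replicated softmax statistics follow the prev_softmax_max description, and the int64 dtype of actual_seq_qlen is an assumption):

>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> np.random.seed(123)
>>> T, N, D = 16, 8, 64                                   # D must be a multiple of 64 for "TND"
>>> def rand_stat(shape):                                 # 8 identical positive values in the last dim
...     return np.broadcast_to(np.random.uniform(0.1, 1.0, size=shape[:-1] + (1,)), shape).astype(np.float32)
>>> prev_attn_out = np.random.uniform(-1.0, 1.0, size=(T, N, D)).astype(np.float32)
>>> prev_softmax_max = rand_stat((T, N, 8))
>>> prev_softmax_sum = rand_stat((T, N, 8))
>>> cur_attn_out = np.random.uniform(-1.0, 1.0, size=(T, N, D)).astype(np.float32)
>>> cur_softmax_max = rand_stat((T, N, 8))
>>> cur_softmax_sum = rand_stat((T, N, 8))
>>> actual_seq_qlen = Tensor(np.array([0, 6, 16], dtype=np.int64))   # two sequences of lengths 6 and 10
>>> inputs_ms = [Tensor(x) for x in (prev_attn_out, prev_softmax_max, prev_softmax_sum,
...                                  cur_attn_out, cur_softmax_max, cur_softmax_sum)]
>>> out = ops.ring_attention_update(*inputs_ms, actual_seq_qlen=actual_seq_qlen, layout="TND")
>>> print(out[0].shape)
(16, 8, 64)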