mindspore.ops.ring_attention_update
- mindspore.ops.ring_attention_update(prev_attn_out, prev_softmax_max, prev_softmax_sum, cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=None, layout='SBH')
The RingAttentionUpdate operator combines the outputs of two FlashAttention operations into a single updated output, rescaling each according to its softmax max and softmax sum values.
S: Sequence length
B: Batch dimension
H: Hidden layer size, equal to N * D
T: Total number of tokens, equal to B * S
N: Number of attention heads
D: Head dimension
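The combination described above matches the standard log-sum-exp ("online softmax") merge. The NumPy sketch below is a hypothetical reference for the "SBH" layout, not the Ascend kernel. It assumes that softmax_max holds the row-wise maximum of the attention scores, that softmax_sum holds the sum of exp(score - max), and that each attn_out is already normalized by its own sum. The helper names ring_attention_update_reference and partial_attention are illustrative only. As a self-check, it splits the keys into two blocks, merges the partial results, and compares them with attention over all keys.

```python
import numpy as np


def ring_attention_update_reference(prev_out, prev_max, prev_sum,
                                    cur_out, cur_max, cur_sum):
    """Hypothetical NumPy reference of the merge for the "SBH" layout.

    prev_out, cur_out: (S, B, H) with H = N * D, each already normalized.
    prev_max, prev_sum, cur_max, cur_sum: (B, N, S, 8), 8 identical values.
    """
    S, B, H = prev_out.shape
    N = prev_max.shape[1]
    D = H // N

    # Use one of the 8 replicated values per row: shape (B, N, S).
    m1, s1 = prev_max[..., 0], prev_sum[..., 0]
    m2, s2 = cur_max[..., 0], cur_sum[..., 0]

    # Common maximum, and each block's sum rescaled to that maximum.
    m = np.maximum(m1, m2)
    w1 = s1 * np.exp(m1 - m)
    w2 = s2 * np.exp(m2 - m)
    s = w1 + w2

    # Mixing weights moved to (S, B, N, 1) so they broadcast over D.
    a1 = np.transpose(w1 / s, (2, 0, 1))[..., None]
    a2 = np.transpose(w2 / s, (2, 0, 1))[..., None]
    out = (prev_out.reshape(S, B, N, D) * a1
           + cur_out.reshape(S, B, N, D) * a2).reshape(S, B, H)

    # Replicate the merged statistics back to a trailing dimension of 8.
    return (out,
            np.repeat(m[..., None], 8, axis=-1),
            np.repeat(s[..., None], 8, axis=-1))


def partial_attention(q, k, v):
    """Softmax attention of q over one key/value block, plus its statistics.

    q: (S, B, N, D); k, v: (Sk, B, N, D). Returns an SBH output and (B, N, S, 8) stats.
    """
    scores = np.einsum("sbnd,tbnd->bnst", q, k) / np.sqrt(q.shape[-1])
    m = scores.max(axis=-1)                      # (B, N, S)
    p = np.exp(scores - m[..., None])
    s = p.sum(axis=-1)                           # (B, N, S)
    out = np.einsum("bnst,tbnd->sbnd", p / s[..., None], v)
    S, B, N, D = out.shape
    return (out.reshape(S, B, N * D),
            np.repeat(m[..., None], 8, axis=-1),
            np.repeat(s[..., None], 8, axis=-1))


# Split the keys into two blocks, merge, and compare with full attention.
rng = np.random.default_rng(0)
S, B, N, D, Sk = 4, 2, 2, 8, 6
q = rng.standard_normal((S, B, N, D))
k = rng.standard_normal((Sk, B, N, D))
v = rng.standard_normal((Sk, B, N, D))

full = partial_attention(q, k, v)
merged = ring_attention_update_reference(*partial_attention(q, k[:3], v[:3]),
                                         *partial_attention(q, k[3:], v[3:]))
print(all(np.allclose(a, b) for a, b in zip(merged, full)))  # True
```

Rescaling both sums to the common maximum before adding them keeps the exponentials bounded. This is why the operator needs the softmax max values and not just the sums.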
Warning
It is only supported on Atlas A2 Training Series Products.
This is an experimental API that is subject to change or deletion.
When layout is "TND", the last dimension of prev_attn_out must be a multiple of 64.
When layout is "TND", actual_seq_qlen is mandatory.
When layout is "TND", N * D must satisfy the constraint \(\text{AlignUp}(N \cdot D, 64) \cdot (\text{DataSize} \cdot 6 + 8) + \text{AlignUp}(N \cdot 8, 64) \cdot 56 \le 192 \cdot 1024\), where \(\text{DataSize}\) is 4 bytes when the dtype of prev_attn_out is float32 and 2 bytes when it is float16 or bfloat16.
When layout is "TND", the result is undefined if actual_seq_qlen is not a non-decreasing sequence from 0 to T.
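The TND restrictions above can be checked before calling the operator. The sketch below evaluates the documented formula exactly as written. It assumes that AlignUp rounds up to the next multiple of 64 and that the "last dimension of prev_attn_out" is D in the (T, N, D) shape. The helper names are hypothetical, and the kernel's own check may differ in detail.

```python
def align_up(x: int, alignment: int) -> int:
    """Round x up to the next multiple of alignment (x itself if already aligned)."""
    return (x + alignment - 1) // alignment * alignment


def tnd_size_ok(n: int, d: int, dtype: str) -> bool:
    """Evaluate the documented N * D size constraint for layout="TND".

    dtype is the dtype name of prev_attn_out: "float32", "float16" or "bfloat16"
    (any name other than "float32" is treated as a 2-byte type here).
    """
    data_size = 4 if dtype == "float32" else 2  # DataSize in bytes
    used = align_up(n * d, 64) * (data_size * 6 + 8) + align_up(n * 8, 64) * 56
    return used <= 192 * 1024


def tnd_constraints_ok(n: int, d: int, dtype: str) -> bool:
    """Combine the size check with the rule that the last dimension (D) is a multiple of 64."""
    return d % 64 == 0 and tnd_size_ok(n, d, dtype)


print(tnd_constraints_ok(32, 128, "float32"))  # True: 145408 <= 196608
print(tnd_constraints_ok(64, 128, "float32"))  # False: 290816 > 196608
print(tnd_constraints_ok(64, 128, "float16"))  # True: 192512 <= 196608
```

Because DataSize enters the formula through a multiplier of 6, switching prev_attn_out from float32 to float16 or bfloat16 can bring a configuration back under the 192 KB limit.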
- Parameters
prev_attn_out (Tensor) – Output of the first FlashAttention operation. The dtype must be float16, float32, or bfloat16. The shape is \((S, B, H)\) when layout is "SBH", or \((T, N, D)\) when layout is "TND".
prev_softmax_max (Tensor) – The max values from the first FlashAttention softmax computation. The dtype must be float32. The shape is \((B, N, S, 8)\) when layout is "SBH", or \((T, N, 8)\) when layout is "TND". The last dimension contains 8 identical values, which must be positive.
prev_softmax_sum (Tensor) – The sum values from the first FlashAttention softmax computation. It has the same shape and dtype as prev_softmax_max.
cur_attn_out (Tensor) – Output of the second FlashAttention operation. It has the same shape and dtype as prev_attn_out.
cur_softmax_max (Tensor) – The max values from the second FlashAttention softmax computation. It has the same shape and dtype as prev_softmax_max.
cur_softmax_sum (Tensor) – The sum values from the second FlashAttention softmax computation. It has the same shape and dtype as prev_softmax_max.
actual_seq_qlen (Tensor, optional) – Cumulative sequence lengths, starting from 0. Required when layout is "TND"; ignored when layout is "SBH". The tensor must be 1D and contain non-decreasing integer values from 0 to T. Default: None.
layout (str, optional) – The input layout. Currently "TND" and "SBH" are supported. Default: "SBH".
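For the "TND" layout, actual_seq_qlen can be derived from the query lengths of the packed sequences. The NumPy sketch below follows the description above literally. It assumes that the leading 0 is part of the tensor and that the last entry equals T. Wrapping the result in a MindSpore Tensor is omitted, because the integer dtype the kernel expects is not stated here. The helper names are hypothetical.

```python
import numpy as np


def build_actual_seq_qlen(seq_lens):
    """Convert per-sequence query lengths into cumulative offsets that start at 0."""
    lens = np.asarray(seq_lens, dtype=np.int64)
    return np.concatenate((np.zeros(1, dtype=np.int64), np.cumsum(lens)))


def is_valid_actual_seq_qlen(actual_seq_qlen, t):
    """Check the documented requirements: 1D, starts at 0, non-decreasing, ends at T."""
    a = np.asarray(actual_seq_qlen)
    return bool(a.ndim == 1 and a.size > 0 and a[0] == 0 and a[-1] == t
                and np.all(np.diff(a) >= 0))


# Three packed sequences of 3, 5 and 2 query tokens, so T = 10.
qlen = build_actual_seq_qlen([3, 5, 2])
print(qlen.tolist())                                 # [0, 3, 8, 10]
print(is_valid_actual_seq_qlen(qlen, 10))            # True
print(is_valid_actual_seq_qlen([0, 5, 3, 10], 10))   # False: decreasing step
```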
- Returns
tuple(Tensor), a tuple of 3 tensors.
attn_out (Tensor) - The updated attention output, with the same shape and dtype as prev_attn_out.
softmax_max (Tensor) - The updated softmax max values, with the same shape and dtype as prev_softmax_max.
softmax_sum (Tensor) - The updated softmax sum values, with the same shape and dtype as prev_softmax_max.
- Raises
RuntimeError – If layout is "TND" and the last dimension of prev_attn_out is not a multiple of 64.
RuntimeError – If layout is "TND" and actual_seq_qlen is not provided.
RuntimeError – If layout is "TND" and actual_seq_qlen is not a non-decreasing sequence from 0 to T.
RuntimeError – If layout is "TND" and prev_attn_out exceeds the size constraint on N * D.
- Supported Platforms:
Ascend
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> np.random.seed(123)
>>> S, B, H, N = 4, 6, 16, 8
>>> # In the default "SBH" layout, attention outputs have shape (S, B, H).
>>> prev_attn_out = np.random.uniform(-1.0, 1.0, size=(S, B, H)).astype(np.float32)
>>> cur_attn_out = np.random.uniform(-1.0, 1.0, size=(S, B, H)).astype(np.float32)
>>> def softmax_stat():
...     # Shape (B, N, S, 8); the 8 values in the last dimension are identical and positive.
...     base = np.random.uniform(0.1, 1.0, size=(B, N, S, 1)).astype(np.float32)
...     return np.repeat(base, 8, axis=-1)
...
>>> prev_softmax_max = softmax_stat()
>>> prev_softmax_sum = softmax_stat()
>>> cur_softmax_max = softmax_stat()
>>> cur_softmax_sum = softmax_stat()
>>> inputs_np = [prev_attn_out, prev_softmax_max, prev_softmax_sum, cur_attn_out, cur_softmax_max, cur_softmax_sum]
>>> inputs_ms = [Tensor(item) for item in inputs_np]
>>> attn_out, softmax_max, softmax_sum = ops.ring_attention_update(*inputs_ms)
>>> print(attn_out.shape)
(4, 6, 16)
>>> print(softmax_max.shape)
(6, 8, 4, 8)