mindspore.ops.ring_attention_update

查看源文件
mindspore.ops.ring_attention_update(prev_attn_out, prev_softmax_max, prev_softmax_sum, cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=None, layout='SBH')[源代码]

RingAttentionUpdate算子功能是将两次FlashAttention的输出根据其不同的softmax的max和sum更新。

  • S:序列长度

  • B:batch维

  • H:隐藏层大小,等于N*D

  • T:time,等于B*S

  • N:注意力头数

  • D:头维度

警告

  • 仅支持 Atlas A2 训练系列产品。

  • 这是一个实验性API,后续可能修改或删除。

  • layout"TND" 时, prev_attn_out 的最后一个维度需要为64的倍数。

  • layout"TND" 时, actual_seq_qlen 为必填。

  • layout"TND" 时,输入Shape不能过大。需要注意N*D的大小,限制为: \((\text{AlignUp}(N*D, 64)*(DataSize*6+8))+(\text{AlignUp}(N*8, 64)*56) <= 192*1024\)prev_attn_out 数据类型为float32时,\(DataSize\) 为4;数据类型为float16 / bfloat16时,\(DataSize\) 为2。

  • layout"TND" 时,如果 actual_seq_qlen 不是从0到T的非递减序列,计算结果未定义。

参数:
  • prev_attn_out (Tensor) - 第一次FlashAttention的输出,数据类型支持float16、float32、bfloat16,输入shape为 \((S, B, H)\)\((T, N, D)\)

  • prev_softmax_max (Tensor) - 第一次FlashAttention的softmax的max结果,数据类型支持float32,输入shape为 \((B,N,S,8)\)\((T,N,8)\),最后一维8个数字相同,且需要为正数。

  • prev_softmax_sum (Tensor) - 第一次FlashAttention的softmax的sum结果,数据类型与shape和 prev_softmax_max 保持一致。

  • cur_attn_out (Tensor) - 第二次FlashAttention的输出,数据类型与shape和 prev_attn_out 保持一致。

  • cur_softmax_max (Tensor) - 第二次FlashAttention的softmax的max结果,数据类型与shape和 prev_softmax_max 保持一致。

  • cur_softmax_sum (Tensor) - 第二次FlashAttention的softmax的sum结果,数据类型与shape和 prev_softmax_max 保持一致。

  • actual_seq_qlen (Tensor,可选) - 从0开始的sequence length的累加,数据类型支持int64。当数据排布 layout"TND" 时,需要传入该参数;当 layout"SBH" 时,该参数不生效。需要是一个从0开始非递减至T的整数1D Tensor。 默认值: None

  • layout (str,可选) - attn_out相关输入的数据排布。当前支持 "TND""SBH" 。默认值: "SBH"

返回:

tuple (Tensor),表示3个Tensor组成的tuple。

  • attn_out (Tensor) - 更新后的attention输出,数据类型与输出shape和 prev_attn_out 保持一致。

  • softmax_max (Tensor) - 更新后的softmax的max,数据类型与输出shape和 prev_softmax_max 保持一致。

  • softmax_sum (Tensor) - 更新后的softmax的sum,数据类型与输出shape和 prev_softmax_max 保持一致。

异常:
  • RuntimeError - 如果 layout"TND" 的时候, prev_attn_out 最后一维没有64对齐。

  • RuntimeError - 如果 layout"TND" 的时候, actual_seq_qlen 没有传入。

  • RuntimeError - 如果 layout"TND" 的时候, actual_seq_qlen 不是从0到T的非递减序列。

  • RuntimeError - 如果 layout"TND" 的时候, prev_attn_out shape过大,不满足约束。

支持平台:

Ascend

样例:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> np.random.seed(123)
>>> S, B, H, N= 4, 6, 16, 8
>>> prev_attn_out = np.random.uniform(-1.0, 1.0, size=(S, B, H)).astype(np.float32)
>>> prev_softmax_max = np.random.uniform(-1.0, 1.0, size=(B, N, S, 8)).astype(np.float32)
>>> prev_softmax_sum = np.random.uniform(-1.0, 1.0, size=(B, N, S, 8)).astype(np.float32)
>>> cur_attn_out = np.random.uniform(-1.0, 1.0, size=(S, B, H)).astype(np.float32)
>>> cur_softmax_max = np.random.uniform(-1.0, 1.0, size=(B, N, S, 8)).astype(np.float32)
>>> cur_softmax_sum = np.random.uniform(-1.0, 1.0, size=(B, N, S, 8)).astype(np.float32)
>>> inputs_np = [prev_attn_out, prev_softmax_max, prev_softmax_sum, cur_attn_out, cur_softmax_max, cur_softmax_sum]
>>> inputs_ms = [Tensor(item) for item in inputs_np]
>>> out = ops.ring_attention_update(*inputs_ms)
>>> print(out[0].shape)
(4, 6, 16)