mindflow.cell.MultiHeadAttention

查看源文件
class mindflow.cell.MultiHeadAttention(in_channels, num_heads, enable_flash_attn=False, fa_dtype=mstype.bfloat16, drop_mode='dropout', dropout_rate=0.0, compute_dtype=mstype.float32)[源代码]

多头注意力机制,具体细节可以参见 Attention Is All You Need

参数:
  • in_channels (int) - 输入的输入特征维度。

  • num_heads (int) - 输出的输出特征维度。

  • enable_flash_attn (bool) - 是否使能FlashAttention。FlashAttention只支持 Ascend 后端。具体细节参见 FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 。 默认值: False

  • fa_dtype (mindspore.dtype): FlashAttention计算类型。支持以下类型: mstype.bfloat16mstype.float16。默认值: mstype.bfloat16 ,表示 mindspore.bfloat16

  • drop_mode (str) - dropout方式。默认值: dropout 。支持以下类型: dropoutdroppath

  • dropout_rate (float) - dropout层丢弃的比率。取值在 [0, 1] 。默认值: 0.0

  • compute_dtype (mindspore.dtype) - 网络层的数据类型。默认值: mstype.float32 ,表示 mindspore.float32

输入:
  • x (Tensor) - shape为 \((batch\_size, sequence\_len, in\_channels)\) 的Tensor。

  • attn_mask (Tensor,可选) - shape为 \((sequence\_len, sequence\_len)\)\((batch\_size, 1, sequence\_len, sequence\_len)\) 的Tensor。默认值: None

  • key_padding_mask (Tensor,可选) - shape为 \((batch\_size, sequence\_len)\) 的Tensor。默认值: None

输出:
  • output (Tensor) - shape为 \((batch\_size, sequence\_len, in\_channels)\) 的Tensor。

支持平台:

Ascend CPU

样例:

>>> from mindspore import ops
>>> from mindflow.cell import MultiHeadAttention
>>> model = MultiHeadAttention(in_channels=512, num_heads=4)
>>> x = ops.rand((2, 32, 512))
>>> mask_shape = (2, 4, 32, 32)
>>> mask = ops.ones(mask_shape)
>>> output = model(x, mask)
>>> print(output.shape)
(2, 32, 512)