mindscience.models.transformer.MultiHeadAttention

class mindscience.models.transformer.MultiHeadAttention(in_channels, num_heads, enable_flash_attn=False, fa_dtype=mstype.bfloat16, drop_mode='dropout', dropout_rate=0.0, compute_dtype=mstype.float32)[source]

Multi-head attention, proposed in Attention Is All You Need.
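
Following the cited paper, each head applies scaled dot-product attention to its projected queries, keys, and values, and the head outputs are concatenated and linearly projected back to in_channels:

\(\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V\)

where \(d_k\) is the per-head dimension, conventionally \(in\_channels / num\_heads\).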

Parameters:
  • in_channels (int) - Number of input channels.

  • num_heads (int) - Number of attention heads.

  • enable_flash_attn (bool) - Whether to use flash attention. Flash attention is only supported on the Ascend backend. Flash attention was proposed in FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. Default: False.

  • fa_dtype (mindspore.dtype) - Compute data type for flash attention. Choose from mstype.bfloat16 or mstype.float16. Default: mstype.bfloat16, which means mindspore.bfloat16.

  • drop_mode (str) - Dropout method; supports "dropout" or "droppath". Default: "dropout". A construction sketch with non-default options follows this parameter list.

  • dropout_rate (float) - Drop rate of the dropout layer, greater than or equal to 0 and less than 1. Default: 0.0.

  • compute_dtype (mindspore.dtype) - Compute data type. Default: mstype.float32, which means mindspore.float32.
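
A minimal construction sketch with non-default options, assuming the optional mask inputs default to None as documented; the parameter values below are illustrative only.

>>> from mindspore import ops, dtype as mstype
>>> from mindscience.models.transformer.attention import MultiHeadAttention
>>> # Illustrative configuration: droppath regularization with a 10% drop rate
>>> model = MultiHeadAttention(in_channels=512, num_heads=8, drop_mode='droppath',
...                            dropout_rate=0.1, compute_dtype=mstype.float32)
>>> x = ops.rand((2, 32, 512))
>>> output = model(x)
>>> print(output.shape)
(2, 32, 512)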

Inputs:
  • x (Tensor) - Tensor of shape \((batch\_size, sequence\_len, in\_channels)\).

  • attn_mask (Tensor, optional) - Tensor of shape \((sequence\_len, sequence\_len)\) or \((batch\_size, 1, sequence\_len, sequence\_len)\). Default: None.

  • key_padding_mask (Tensor, optional) - Tensor of shape \((batch\_size, sequence\_len)\). Default: None. A usage sketch follows the example below.

Outputs:
  • output (Tensor) - Tensor of shape \((batch\_size, sequence\_len, in\_channels)\).

Examples:

>>> from mindspore import ops
>>> from mindscience.models.transformer.attention import MultiHeadAttention
>>> model = MultiHeadAttention(in_channels=512, num_heads=4)
>>> x = ops.rand((2, 32, 512))
>>> mask_shape = (32, 32)
>>> mask = ops.ones(mask_shape)
>>> output = model(x, mask)
>>> print(output.shape)
(2, 32, 512)
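
A further sketch using key_padding_mask, assuming the inputs are passed positionally in the documented order (x, attn_mask, key_padding_mask) and that the mask follows the usual convention of marking padded positions; both assumptions are not confirmed by this page.

>>> from mindspore import ops
>>> from mindscience.models.transformer.attention import MultiHeadAttention
>>> model = MultiHeadAttention(in_channels=512, num_heads=4)
>>> x = ops.rand((2, 32, 512))
>>> # Illustrative padding mask: all zeros, i.e. no positions are treated as padding
>>> key_padding_mask = ops.zeros((2, 32))
>>> output = model(x, None, key_padding_mask)
>>> print(output.shape)
(2, 32, 512)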