mindflow.cell.MultiHeadAttention

class mindflow.cell.MultiHeadAttention(in_channels, num_heads, drop_mode='dropout', dropout_rate=0.0, compute_dtype=mstype.float32)[source]

Multi-head attention proposed in Attention Is All You Need.

Parameters
  • in_channels (int) – The input channels.

  • num_heads (int) – The number of attention heads.

  • drop_mode (str) – Dropout method. Supports "dropout" or "droppath". Default: "dropout".

  • dropout_rate (float) – The drop rate of the dropout layer, which should be in the range [0, 1). Default: 0.0.

  • compute_dtype (mindspore.dtype) – The compute data type. Default: mstype.float32, which indicates mindspore.float32.

Inputs:
  • x (Tensor) - Tensor with shape \((batch\_size, sequence\_len, in\_channels)\).

  • attn_mask (Tensor) - Tensor with shape \((batch\_size, sequence\_len, sequence\_len)\) or \((sequence\_len, sequence\_len)\) or \((batch\_size, num\_heads, sequence\_len, sequence\_len)\).

  • key_padding_mask (Tensor) - Tensor with shape \((batch\_size, sequence\_len)\) or \((batch\_size, sequence\_len, sequence\_len)\) or \((batch\_size, num\_heads, sequence\_len, sequence\_len)\).

Outputs:
  • output (Tensor) - Tensor with shape \((batch\_size, sequence\_len, in\_channels)\).

Supported Platforms:

Ascend CPU

Examples

>>> from mindspore import ops
>>> from mindflow.cell import MultiHeadAttention
>>> model = MultiHeadAttention(in_channels=512, num_heads=4)
>>> x = ops.rand((2, 32, 512))
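>>> # attn_mask with shape (batch_size, num_heads, sequence_len, sequence_len)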
>>> mask_shape = (2, 4, 32, 32)
>>> mask = ops.ones(mask_shape)
>>> output = model(x, mask)
>>> print(output.shape)
(2, 32, 512)
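The forward call also accepts the masks listed under Inputs. The following is a minimal sketch, assuming attn_mask and key_padding_mask can be passed positionally in the order listed under Inputs, and that a \((sequence\_len, sequence\_len)\) mask is broadcast over batch and heads; the expected mask dtypes and semantics (which positions are masked out) follow the library's convention and are only illustrative here.

>>> from mindspore import ops
>>> from mindflow.cell import MultiHeadAttention
>>> model = MultiHeadAttention(in_channels=512, num_heads=4, drop_mode='droppath', dropout_rate=0.1)
>>> x = ops.rand((2, 32, 512))
>>> # (sequence_len, sequence_len) upper-triangular attention mask, assumed to broadcast over batch and heads
>>> attn_mask = ops.triu(ops.ones((32, 32)), diagonal=1)
>>> # (batch_size, sequence_len) key padding mask; all zeros here, i.e. no padded positions
>>> key_padding_mask = ops.zeros((2, 32))
>>> output = model(x, attn_mask, key_padding_mask)
>>> print(output.shape)
(2, 32, 512)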