mindscience.models.transformer.TransformerBlock

class mindscience.models.transformer.TransformerBlock(in_channels, num_heads, enable_flash_attn=False, fa_dtype=mstype.bfloat16, drop_mode='dropout', dropout_rate=0.0, compute_dtype=mstype.float32)[source]

TransformerBlock comprises a MultiHeadAttention and a FeedForward layer.
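
As a rough sketch only (the exact placement of normalization and residual connections is an implementation detail of this class and may differ), such a block computes \(y = x + \mathrm{MultiHeadAttention}(\mathrm{Norm}(x))\) followed by \(output = y + \mathrm{FeedForward}(\mathrm{Norm}(y))\).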

Parameters
  • in_channels (int) – The input channels.

  • num_heads (int) – The number of attention heads.

  • enable_flash_attn (bool, optional) – Whether to use FlashAttention. FlashAttention is supported only on the Ascend backend. FlashAttention was proposed in FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. Default: False.

  • fa_dtype (mindspore.dtype, optional) – Compute dtype of FlashAttention. Choose from mstype.bfloat16 or mstype.float16. Default: mstype.bfloat16, which indicates mindspore.bfloat16.

  • drop_mode (str, optional) – Dropout method. Supports "dropout" or "droppath". Default: "dropout". See the construction sketch after this parameter list.

  • dropout_rate (float, optional) – The drop rate of the dropout layer, greater than or equal to 0 and less than or equal to 1. Default: 0.0.

  • compute_dtype (mindspore.dtype, optional) – Compute dtype. Default: mstype.float32, which indicates mindspore.float32.
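
For illustration, the optional arguments can be combined as in the following construction sketch (the specific values are arbitrary, not recommendations):

>>> from mindspore import dtype as mstype
>>> from mindscience.models.transformer.attention import TransformerBlock
>>> # build a block with droppath-style regularization and an explicit compute dtype
>>> block = TransformerBlock(in_channels=256, num_heads=4, drop_mode='droppath',
...                          dropout_rate=0.1, compute_dtype=mstype.float32)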

Inputs:
  • x (Tensor) - Tensor with shape \((batch\_size, sequence\_len, in\_channels)\).

  • mask (Tensor, optional) - Tensor with shape \((sequence\_len, sequence\_len)\) or \((batch\_size, 1, sequence\_len, sequence\_len)\). Default: None. A masked usage sketch is given in the Examples section.

Outputs:
  • output (Tensor) - Tensor with shape \((batch\_size, sequence\_len, in\_channels)\).

Examples

>>> from mindspore import ops
>>> from mindscience.models.transformer.attention import TransformerBlock
>>> model = TransformerBlock(in_channels=256, num_heads=4)
>>> x = ops.rand((4, 100, 256))
>>> output = model(x)
>>> print(output.shape)
(4, 100, 256)
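
As a further sketch, an attention mask can be passed as the second input. The all-ones mask below is purely illustrative; the expected mask semantics (e.g. boolean versus additive) follow the underlying MultiHeadAttention implementation.

>>> import mindspore as ms
>>> from mindspore import ops
>>> from mindscience.models.transformer.attention import TransformerBlock
>>> model = TransformerBlock(in_channels=256, num_heads=4)
>>> x = ops.rand((4, 100, 256))
>>> # an all-ones mask of shape (sequence_len, sequence_len); illustrative only
>>> mask = ops.ones((100, 100), dtype=ms.float32)
>>> output = model(x, mask)
>>> print(output.shape)
(4, 100, 256)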