mindscience.models.transformer.MultiHeadAttention
- class mindscience.models.transformer.MultiHeadAttention(in_channels, num_heads, enable_flash_attn=False, fa_dtype=mstype.bfloat16, drop_mode='dropout', dropout_rate=0.0, compute_dtype=mstype.float32)[source]
Multi-head attention, as proposed in Attention Is All You Need.
- Parameters
in_channels (int) – The input channels.
num_heads (int) – The number of attention heads.
enable_flash_attn (bool) – Whether to use flash attention. FlashAttention only supports the Ascend backend. FlashAttention is proposed in FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. Default: False.
fa_dtype (mindspore.dtype) – FlashAttention compute dtype. Choose from mstype.bfloat16 and mstype.float16. Default: mstype.bfloat16, indicates mindspore.bfloat16.
drop_mode (str) – Dropout method. Supports "dropout" or "droppath". Default: "dropout".
dropout_rate (float) – The drop rate of the dropout layer, greater than or equal to 0 and less than or equal to 1. Default: 0.0.
compute_dtype (mindspore.dtype) – Compute dtype. Default: mstype.float32, indicates mindspore.float32.
- Inputs:
x (Tensor) - Tensor with shape \((batch\_size, sequence\_len, in\_channels)\).
attn_mask (Tensor, optional) - Tensor with shape \((sequence\_len, sequence\_len)\) or \((batch\_size, 1, sequence\_len, sequence\_len)\). Default: None.
key_padding_mask (Tensor, optional) - Tensor with shape \((batch\_size, sequence\_len)\). Default: None.
- Outputs:
output (Tensor) - Tensor with shape \((batch\_size, sequence\_len, in\_channels)\).
Examples
>>> from mindspore import ops
>>> from mindscience.models.transformer.attention import MultiHeadAttention
>>> model = MultiHeadAttention(in_channels=512, num_heads=4)
>>> x = ops.rand((2, 32, 512))
>>> mask_shape = (32, 32)
>>> mask = ops.ones(mask_shape)
>>> output = model(x, mask)
>>> print(output.shape)
(2, 32, 512)
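The sketch below exercises the optional inputs, assuming the forward call accepts attn_mask and key_padding_mask positionally in the order listed under Inputs; the mask shapes follow the Inputs section, while the mask values and dtypes are only illustrative.
>>> from mindspore import ops
>>> from mindscience.models.transformer.attention import MultiHeadAttention
>>> model = MultiHeadAttention(in_channels=512, num_heads=4, dropout_rate=0.1)
>>> x = ops.rand((2, 32, 512))
>>> # Batched attention mask with shape (batch_size, 1, sequence_len, sequence_len)
>>> attn_mask = ops.ones((2, 1, 32, 32))
>>> # Per-token padding mask with shape (batch_size, sequence_len); dtype/semantics follow the implementation
>>> key_padding_mask = ops.zeros((2, 32))
>>> output = model(x, attn_mask, key_padding_mask)
>>> print(output.shape)
(2, 32, 512)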