mindscience.models.transformer.Attention
- class mindscience.models.transformer.Attention(in_channels, num_heads, compute_dtype=mstype.float32)
Attention implementation base class
- Parameters
in_channels (int) - Number of input channels, i.e. the embedding dimension of the input sequence.
num_heads (int) - Number of attention heads.
compute_dtype (mindspore.dtype) - Compute data type. Default: mstype.float32.
- Inputs:
x (Tensor) - Tensor with shape \((batch\_size, sequence\_len, in\_channels)\).
attn_mask (Tensor, optional) - Tensor with shape \((sequence\_len, sequence\_len)\) or \((batch\_size, 1, sequence\_len, sequence\_len)\). Default: None.
key_padding_mask (Tensor, optional) - Tensor with shape \((batch\_size, sequence\_len)\). Default: None.
- Outputs:
output (Tensor) - Tensor with shape \((batch\_size, sequence\_len, in\_channels)\).
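The following is a minimal sketch, not part of the official documentation, of how tensors matching the documented input shapes could be constructed. It assumes batch_size=2, sequence_len=32, in_channels=512, and the convention that nonzero mask entries mark positions to be excluded from attention is an assumption, not a documented guarantee.

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore import dtype as mstype
>>> batch_size, sequence_len, in_channels = 2, 32, 512
>>> # Input sequence: (batch_size, sequence_len, in_channels)
>>> x = Tensor(np.random.rand(batch_size, sequence_len, in_channels), mstype.float32)
>>> # Causal attention mask: (sequence_len, sequence_len); the upper triangle
>>> # blocks attention to future positions (masking convention assumed).
>>> attn_mask = Tensor(np.triu(np.ones((sequence_len, sequence_len)), k=1), mstype.float32)
>>> # Key padding mask: (batch_size, sequence_len); here the last 4 tokens of
>>> # each sequence are treated as padding (hypothetical layout).
>>> pad = np.zeros((batch_size, sequence_len), dtype=np.float32)
>>> pad[:, -4:] = 1.0
>>> key_padding_mask = Tensor(pad)
>>> print(attn_mask.shape, key_padding_mask.shape)
(32, 32) (2, 32)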
Examples
>>> from mindspore import ops
>>> from mindscience.models.transformer.attention import Attention
>>> model = Attention(in_channels=512, num_heads=4)
>>> x = ops.rand((2, 32, 512))
>>> q, k, v = model.get_qkv(x)
>>> print(q.shape)
(2, 4, 32, 128)
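The head dimension in the example is in_channels / num_heads = 512 / 4 = 128, which matches the printed shape (2, 4, 32, 128). Building on that, the following is a hedged sketch, not part of the documented API, of the standard scaled dot-product attention a subclass could compute from the get_qkv outputs:

>>> import mindspore.ops as ops
>>> from mindscience.models.transformer.attention import Attention
>>> model = Attention(in_channels=512, num_heads=4)
>>> x = ops.rand((2, 32, 512))
>>> q, k, v = model.get_qkv(x)
>>> head_dim = q.shape[-1]  # 128 = in_channels / num_heads
>>> # Scaled dot-product attention over each head.
>>> scores = ops.matmul(q, k.transpose(0, 1, 3, 2)) / head_dim ** 0.5
>>> weights = ops.softmax(scores, axis=-1)
>>> out = ops.matmul(weights, v)
>>> print(out.shape)
(2, 4, 32, 128)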