mindscience.models.transformer.Attention

class mindscience.models.transformer.Attention(in_channels, num_heads, compute_dtype=mstype.float32)

Base class for attention implementations.

Parameters
  • in_channels (int) – The dimension of the input vectors.

  • num_heads (int) – The number of attention heads.

  • compute_dtype (mindspore.dtype, optional) – The data type used for computation. Default: mstype.float32, i.e. mindspore.float32.
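The Examples section reports q.shape as (2, 4, 32, 128) for in_channels=512 and num_heads=4, which indicates that each head operates on in_channels // num_heads = 128 channels. The following is a minimal NumPy sketch of that head split and its inverse. It assumes the usual reshape-and-transpose layout; split_heads and merge_heads are hypothetical helpers, not MindScience APIs, and the divisibility check is an assumption of this sketch.

```python
import numpy as np

def split_heads(x, num_heads):
    """Hypothetical helper: reshape (batch_size, sequence_len, in_channels)
    into (batch_size, num_heads, sequence_len, in_channels // num_heads)."""
    batch_size, seq_len, in_channels = x.shape
    # Assumption of this sketch: channels must divide evenly across heads.
    if in_channels % num_heads != 0:
        raise ValueError("in_channels must be divisible by num_heads")
    head_dim = in_channels // num_heads
    # Split the channel axis into (num_heads, head_dim), then move heads before the sequence axis.
    return x.reshape(batch_size, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)

def merge_heads(x):
    """Hypothetical helper, inverse of split_heads: (B, H, L, D) -> (B, L, H * D)."""
    batch_size, num_heads, seq_len, head_dim = x.shape
    return x.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, num_heads * head_dim)

x = np.random.rand(2, 32, 512)
heads = split_heads(x, num_heads=4)
print(heads.shape)  # (2, 4, 32, 128)
assert np.array_equal(merge_heads(heads), x)  # the round trip restores the input exactly
```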

Inputs:
  • x (Tensor) - Tensor with shape \((batch\_size, sequence\_len, in\_channels)\).

  • attn_mask (Tensor, optional) - Tensor with shape \((sequence\_len, sequence\_len)\) or \((batch\_size, 1, sequence\_len, sequence\_len)\). Default: None.

  • key_padding_mask (Tensor, optional) - Tensor with shape \((batch\_size, sequence\_len)\). Default: None.
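The two masks have different shapes, but both can be broadcast to a common \((batch\_size, 1, sequence\_len, sequence\_len)\) layout. The sketch below shows one way to merge them in NumPy. This page does not state whether a mask value of True/1 means "keep" or "exclude"; the sketch assumes True marks a position to exclude. combine_masks is a hypothetical helper, not part of MindScience.

```python
import numpy as np

def combine_masks(batch_size, seq_len, attn_mask=None, key_padding_mask=None):
    """Hypothetical helper: merge the optional masks into one boolean array of shape
    (batch_size, 1, seq_len, seq_len). True marks a position to exclude (assumed convention)."""
    mask = np.zeros((batch_size, 1, seq_len, seq_len), dtype=bool)
    if attn_mask is not None:
        attn_mask = np.asarray(attn_mask, dtype=bool)
        if attn_mask.ndim == 2:
            # (seq_len, seq_len): the same mask is shared by every batch element.
            attn_mask = attn_mask[None, None, :, :]
        mask = mask | attn_mask  # broadcasts to (batch_size, 1, seq_len, seq_len)
    if key_padding_mask is not None:
        # (batch_size, seq_len) -> (batch_size, 1, 1, seq_len): hides padded keys for every query.
        kpm = np.asarray(key_padding_mask, dtype=bool)[:, None, None, :]
        mask = mask | kpm
    return mask

causal = np.triu(np.ones((4, 4), dtype=bool), k=1)  # hide future positions
padding = np.array([[False, False, False, True],     # last key of sample 0 is padding
                    [False, False, False, False]])
m = combine_masks(2, 4, attn_mask=causal, key_padding_mask=padding)
print(m.shape)  # (2, 1, 4, 4)
```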

Outputs:
  • output (Tensor) - Tensor with shape \((batch\_size, sequence\_len, in\_channels)\).
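Because Attention is documented as a base class, this page does not specify how its subclasses compute the output. As an assumption, the sketch below uses standard scaled dot-product attention over already-split heads and then merges the heads back into the \((batch\_size, sequence\_len, in\_channels)\) output layout. It is written in NumPy rather than MindSpore. The function name attention and the boolean mask convention (True means exclude) are hypothetical.

```python
import numpy as np

def attention(q, k, v, mask=None):
    """Hypothetical sketch of scaled dot-product attention.
    q, k, v: (batch_size, num_heads, sequence_len, head_dim).
    mask: optional boolean array broadcastable to (batch_size, num_heads, sequence_len, sequence_len);
    True marks a position to exclude (assumed convention).
    Returns (batch_size, sequence_len, num_heads * head_dim)."""
    head_dim = q.shape[-1]
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(head_dim)  # (B, H, L, L)
    if mask is not None:
        scores = np.where(mask, -1e9, scores)  # excluded positions get ~zero weight
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    out = weights @ v  # (B, H, L, head_dim)
    b, h, l, d = out.shape
    # Merge heads back into the channel axis: (B, L, H * head_dim).
    return out.transpose(0, 2, 1, 3).reshape(b, l, h * d)

rng = np.random.default_rng(0)
q = k = v = rng.random((2, 4, 32, 128))
out = attention(q, k, v)
print(out.shape)  # (2, 32, 512)
```

With a causal mask, the first query can attend only to the first key, so its output equals the first value vector of each head, concatenated across heads.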

Examples

>>> from mindspore import ops
>>> from mindscience.models.transformer.attention import Attention
>>> model = Attention(in_channels=512, num_heads=4)
>>> x = ops.rand((2, 32, 512))
>>> q, k, v = model.get_qkv(x)
>>> print(q.shape)
(2, 4, 32, 128)