mindscience.models.transformer.Attention
- class mindscience.models.transformer.Attention(in_channels, num_heads, compute_dtype=mstype.float32)[源代码]
注意力实现基类
- 参数:
in_channels (int) - 输入向量的维度。
num_heads (int) - 注意力头的数量。
compute_dtype (mindspore.dtype, 可选) - 计算数据类型。默认值:
mstype.float32,表示mindspore.float32。
- 输入:
x (Tensor) - 形状为 \((batch\_size, sequence\_len, in\_channels)\) 的张量。
attn_mask (Tensor, 可选) - 形状为 \((sequence\_len, sequence\_len)\) 或 \((batch\_size, 1, sequence\_len, sequence\_len)\) 的张量。默认值:
None。key_padding_mask (Tensor, 可选) - 形状为 \((batch\_size, sequence\_len)\) 的张量。默认值:
None。
- 输出:
output (Tensor) - 形状为 \((batch\_size, sequence\_len, in\_channels)\) 的张量。
样例:
>>> from mindspore import ops >>> from mindscience.models.transformer.attention import Attention >>> model = Attention(in_channels=512, num_heads=4) >>> x = ops.rand((2, 32, 512)) >>> q, k, v = model.get_qkv(x) >>> print(q.shape) (2, 4, 32, 128)