mindsponge.cell.Attention

class mindsponge.cell.Attention(num_head, hidden_size, gating, q_data_dim, m_data_dim, output_dim, batch_size=None)

This is an implementation of multi-head attention from the paper "Attention Is All You Need". Given a query tensor of length query_seq_length and a key/value tensor of length value_seq_length, the attention is computed as follows:

\[MultiHead(query, key, value) = Concat(head_1, \dots, head_h)W^O\]

where \(head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)\), and \(Q\), \(K\) and \(V\) are the query, key and value. A bias is included by default.

If the query, key and value tensors are the same, the layer computes a modified version of self-attention.
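To make the formula concrete, here is a minimal NumPy sketch of multi-head attention with the documented input shapes. It is a hypothetical illustration, not MindSponge's implementation: the weight names (w_q, w_k, w_v, w_o) are assumptions, and it omits attention_mask, nonbatched_bias, gating and bias terms. Passing the same tensor as query and key/value gives self-attention, which requires q_data_dim == m_data_dim.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multihead_attention(q_data, m_data, w_q, w_k, w_v, w_o, num_head):
    """Hypothetical sketch of MultiHead(Q, K, V) = Concat(head_1..head_h) W^O.

    q_data: (batch, q_len, q_data_dim)
    m_data: (batch, v_len, m_data_dim)
    w_q: (q_data_dim, hidden_size); w_k, w_v: (m_data_dim, hidden_size)
    w_o: (hidden_size, output_dim)
    """
    batch, q_len, _ = q_data.shape
    v_len = m_data.shape[1]
    hidden_size = w_q.shape[1]
    d = hidden_size // num_head  # per-head dimension

    # Project, then split the hidden dimension into heads:
    # (batch, len, hidden_size) -> (batch, num_head, len, d)
    q = (q_data @ w_q).reshape(batch, q_len, num_head, d).transpose(0, 2, 1, 3)
    k = (m_data @ w_k).reshape(batch, v_len, num_head, d).transpose(0, 2, 1, 3)
    v = (m_data @ w_v).reshape(batch, v_len, num_head, d).transpose(0, 2, 1, 3)

    # Scaled dot-product attention per head: logits are (batch, num_head, q_len, v_len)
    logits = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d)
    weights = softmax(logits, axis=-1)
    heads = weights @ v  # (batch, num_head, q_len, d)

    # Concat(head_1, ..., head_h) -> (batch, q_len, hidden_size), then apply W^O
    concat = heads.transpose(0, 2, 1, 3).reshape(batch, q_len, hidden_size)
    return concat @ w_o


rng = np.random.default_rng(0)
batch, q_len, v_len = 2, 5, 7
q_data_dim, m_data_dim, hidden_size, output_dim, num_head = 8, 6, 16, 4, 4
q_data = rng.standard_normal((batch, q_len, q_data_dim))
m_data = rng.standard_normal((batch, v_len, m_data_dim))
w_q = rng.standard_normal((q_data_dim, hidden_size))
w_k = rng.standard_normal((m_data_dim, hidden_size))
w_v = rng.standard_normal((m_data_dim, hidden_size))
w_o = rng.standard_normal((hidden_size, output_dim))
out = multihead_attention(q_data, m_data, w_q, w_k, w_v, w_o, num_head)
print(out.shape)  # (2, 5, 4)

# Self-attention: q_data is used as both query and key/value.
w_k_self = rng.standard_normal((q_data_dim, hidden_size))
w_v_self = rng.standard_normal((q_data_dim, hidden_size))
self_out = multihead_attention(q_data, q_data, w_q, w_k_self, w_v_self, w_o, num_head)
print(self_out.shape)  # (2, 5, 4)
```

Splitting one hidden_size-wide projection into num_head slices (rather than keeping separate per-head matrices) is a common way to compute all heads in one matrix multiply.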

Parameters
  • num_head (int) – The number of attention heads.

  • hidden_size (int) – The hidden size of the input.

  • gating (bool) – Whether the attention is gated.

  • q_data_dim (int) – The last dimension length of the query tensor.

  • m_data_dim (int) – The last dimension length of the key and value tensor.

  • output_dim (int) – The last dimension length of the output tensor.

  • batch_size (int) – The batch size of the attention parameters, used with while control flow. Default: None.
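A minimal NumPy sketch of what gating=True can mean, assuming the AlphaFold-style scheme in which a sigmoid gate computed from the query input multiplies the concatenated head outputs elementwise before the output projection. The function apply_gate and the parameters w_gate and b_gate are hypothetical names; check the MindSponge source to confirm the exact scheme this cell uses.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def apply_gate(weighted_avg, q_data, w_gate, b_gate):
    """Hypothetical gating step (assumed AlphaFold-style).

    weighted_avg: (batch, q_len, hidden_size)  concatenated head outputs
    q_data:       (batch, q_len, q_data_dim)   query input
    w_gate:       (q_data_dim, hidden_size)    hypothetical gate weights
    b_gate:       (hidden_size,)               hypothetical gate bias
    """
    gate = sigmoid(q_data @ w_gate + b_gate)  # every gate value lies in (0, 1)
    return weighted_avg * gate


# With zero gate weights and bias the gate is sigmoid(0) = 0.5 everywhere,
# so gating halves the attention output.
batch, q_len, q_data_dim, hidden_size = 2, 3, 8, 16
weighted_avg = np.ones((batch, q_len, hidden_size))
q_data = np.ones((batch, q_len, q_data_dim))
gated = apply_gate(weighted_avg, q_data,
                   np.zeros((q_data_dim, hidden_size)), np.zeros(hidden_size))
print(gated[0, 0, :4])  # [0.5 0.5 0.5 0.5]
```

Because the gate depends on the query, each query position can scale down its own attention result, which a plain (ungated) layer cannot do.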

Inputs:
  • q_data (Tensor) - The query tensor with shape (batch_size, query_seq_length, q_data_dim), where query_seq_length is the query sequence length.

  • m_data (Tensor) - The key/value tensor with shape (batch_size, value_seq_length, m_data_dim), where value_seq_length is the value sequence length.

  • attention_mask (Tensor) - The mask for attention matrix with shape (batch_size, num_head, query_seq_length, value_seq_length).

  • index (Tensor) - The index of the while loop, used only with while control flow. Default: None.

  • nonbatched_bias (Tensor) - Non-batched bias for the attention matrix with shape (num_head, query_seq_length, value_seq_length). Default: None.
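A minimal NumPy sketch of preparing inputs with the documented shapes, assuming the mask starts as a per-key padding vector of 1s (keep) and 0s (padding). It broadcasts that vector to (batch_size, num_head, query_seq_length, value_seq_length) and, as an assumed AlphaFold-style alternative, converts it to an additive bias. Whether this cell expects a 0/1 mask or an additive bias is not stated here, so check the MindSponge source before relying on either convention. Adding a leading axis to nonbatched_bias so it is shared across the batch is likewise an assumption.

```python
import numpy as np

batch, num_head, q_len, v_len = 2, 4, 3, 5

# Hypothetical per-key padding mask: 1 = real token, 0 = padding.
key_padding = np.array([[1, 1, 1, 0, 0],
                        [1, 1, 1, 1, 1]], dtype=np.float32)  # (batch, v_len)

# Broadcast to (batch_size, num_head, query_seq_length, value_seq_length).
attention_mask = np.broadcast_to(
    key_padding[:, None, None, :], (batch, num_head, q_len, v_len)).copy()

# Assumed additive form: 0 for kept keys, -1e9 for padded keys, so padded
# keys receive (almost) zero weight after the softmax.
additive_bias = 1e9 * (attention_mask - 1.0)

# nonbatched_bias is (num_head, q_len, v_len); a leading axis of size 1
# lets it broadcast over the batch dimension of the attention logits.
nonbatched_bias = np.zeros((num_head, q_len, v_len), dtype=np.float32)
logits = np.zeros((batch, num_head, q_len, v_len), dtype=np.float32)
logits = logits + additive_bias + nonbatched_bias[None]
print(attention_mask.shape)  # (2, 4, 3, 5)
```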

Outputs:

Tensor, the output of the Attention layer with shape (batch_size, query_seq_length, output_dim).
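The shape contract in the Inputs list can be checked before calling the cell. The pure-Python helper below is hypothetical (not part of MindSponge). It returns (batch_size, query_seq_length, output_dim), consistent with the output_dim parameter description, and it additionally assumes that hidden_size must be divisible by num_head so the hidden size splits evenly across heads.

```python
def attention_output_shape(q_shape, m_shape, mask_shape, num_head, hidden_size,
                           q_data_dim, m_data_dim, output_dim):
    """Hypothetical helper: validate shapes and return the expected output shape."""
    batch, q_len, q_dim = q_shape
    m_batch, v_len, m_dim = m_shape
    # Assumption: the hidden size is split evenly across the heads.
    if hidden_size % num_head != 0:
        raise ValueError("hidden_size must be divisible by num_head")
    if q_dim != q_data_dim or m_dim != m_data_dim:
        raise ValueError("last dimensions must match q_data_dim / m_data_dim")
    if m_batch != batch:
        raise ValueError("q_data and m_data must share the batch size")
    if tuple(mask_shape) != (batch, num_head, q_len, v_len):
        raise ValueError("attention_mask must be (batch, num_head, q_len, v_len)")
    return (batch, q_len, output_dim)


# Same configuration as the example below.
print(attention_output_shape((32, 128, 64), (32, 256, 64), (32, 4, 128, 256),
                             num_head=4, hidden_size=64, q_data_dim=64,
                             m_data_dim=64, output_dim=64))
# (32, 128, 64)
```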

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> from mindsponge.cell import Attention
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor
>>> model = Attention(num_head=4, hidden_size=64, gating=True, q_data_dim=64,
...                   m_data_dim=64, output_dim=64)
>>> q_data = Tensor(np.ones((32, 128, 64)), mstype.float32)
>>> m_data = Tensor(np.ones((32, 256, 64)), mstype.float32)
>>> attention_mask = Tensor(np.ones((32, 4, 128, 256)), mstype.float32)
>>> attn_out = model(q_data, m_data, attention_mask)
>>> print(attn_out.shape)
(32, 128, 64)