mindsponge.cell.MSAColumnGlobalAttention

class mindsponge.cell.MSAColumnGlobalAttention(num_head, gating, msa_act_dim, batch_size=None, slice_num=0)[source]

MSA column global attention. It transposes the MSA representation along the sequence and residue axes, then uses GlobalAttention to perform attention between the input sequences, without modelling the relationships between residues within a sequence. Compared with MSAColumnAttention, it uses GlobalAttention and can therefore handle longer input sequences.

Reference:

Jumper et al. (2021) Suppl. Alg. 19 ‘MSAColumnGlobalAttention’.
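The sketch below illustrates the computation pattern described above in plain NumPy: the sequence and residue axes are swapped, a single mask-averaged query per residue column attends over all sequences, and the result is broadcast back. It is a minimal single-head sketch with gating and bias terms omitted; the array names and weight shapes are illustrative assumptions, not the module's actual parameters.

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Illustrative sizes and single-head weights (hypothetical, not the module's parameters).
n_seq, n_res, c_m, c = 4, 8, 16, 8
rng = np.random.default_rng(0)
msa_act = rng.standard_normal((n_seq, n_res, c_m))
msa_mask = np.ones((n_seq, n_res))
w_q, w_k, w_v = (rng.standard_normal((c_m, c)) for _ in range(3))
w_o = rng.standard_normal((c, c_m))

# 1. Swap the sequence and residue axes so attention runs over the
#    sequences within each residue column.
x = msa_act.transpose(1, 0, 2)                     # (n_res, n_seq, c_m)
mask = msa_mask.transpose(1, 0)                    # (n_res, n_seq)

# 2. Global attention: one query per column, taken as the mask-weighted
#    average over the sequences.
q_avg = (mask[..., None] * x).sum(1) / (mask.sum(1, keepdims=True) + 1e-10)
q = q_avg @ w_q                                    # (n_res, c)
k, v = x @ w_k, x @ w_v                            # (n_res, n_seq, c)

# 3. One attention weight per sequence in each column.
logits = np.einsum('rc,rsc->rs', q, k) / np.sqrt(c)
weights = softmax(np.where(mask > 0, logits, -1e9), axis=-1)

# 4. Weighted sum of values, output projection, broadcast back to every
#    sequence, then swap the axes back (gating is omitted here; with
#    gating the rows of a column are no longer identical).
out = np.einsum('rs,rsc->rc', weights, v) @ w_o    # (n_res, c_m)
msa_out = np.broadcast_to(out[:, None, :], (n_res, n_seq, c_m)).transpose(1, 0, 2)
print(msa_out.shape)                               # (4, 8, 16)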

Parameters
  • num_head (int) – The number of attention heads.

  • gating (bool) – Indicator of whether the attention is gated.

  • msa_act_dim (int) – The dimension of the msa_act.

  • batch_size (int) – The batch size of parameters in MSAColumnGlobalAttention, used in while control flow (see the sketch after this parameter list). Default: None.

  • slice_num (int) – The number of slices to be made to reduce memory. Default: 0.
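When batch_size is set, the weights are stored per batch index and the index input selects which set to use inside a while loop. The lines below are a hedged sketch of that construction, assuming batch_size=2 and an int32 scalar index; both values are illustrative only.

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore import dtype as mstype
>>> from mindsponge.cell import MSAColumnGlobalAttention
>>> model = MSAColumnGlobalAttention(num_head=4, gating=True, msa_act_dim=64, batch_size=2)
>>> msa_act = Tensor(np.ones((4, 256, 64)), mstype.float32)
>>> msa_mask = Tensor(np.ones((4, 256)), mstype.float16)
>>> index = Tensor(0, mstype.int32)
>>> msa_out = model(msa_act, msa_mask, index)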

Inputs:
  • msa_act (Tensor) - Tensor of msa_act with shape \((N_{seqs}, N_{res}, msa\_act\_dim)\) .

  • msa_mask (Tensor) - The mask for msa_act matrix with shape \((N_{seqs}, N_{res})\) .

  • index (Tensor) - The index of the while loop, only used in case of while control flow. Default: None.

Outputs:

Tensor, the float tensor of the msa_act of the layer with shape \((N_{seqs}, N_{res}, msa\_act\_dim)\) .

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> from mindsponge.cell import MSAColumnGlobalAttention
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor
>>> model = MSAColumnGlobalAttention(num_head=4, gating=True, msa_act_dim=64, batch_size=None)
>>> msa_act = Tensor(np.ones((4, 256, 64)), mstype.float32)
>>> msa_mask = Tensor(np.ones((4, 256)), mstype.float16)
>>> index = None
>>> msa_out = model(msa_act, msa_mask, index)
>>> print(msa_out.shape)
(4, 256, 64)
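
If memory is a concern, slice_num can be raised to chunk the computation; the value 2 below is an illustrative assumption, and the documented output shape is unchanged.

>>> model = MSAColumnGlobalAttention(num_head=4, gating=True, msa_act_dim=64, batch_size=None, slice_num=2)
>>> msa_out = model(msa_act, msa_mask, None)
>>> print(msa_out.shape)
(4, 256, 64)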