mindsponge.cell.MSAColumnAttention

class mindsponge.cell.MSAColumnAttention(num_head, key_dim, gating, msa_act_dim, batch_size=None, slice_num=0)[源代码]

MSA列注意力层。 MSA逐列注意模块,让处于相同序列位置的信息进行交互。 参考文献:Jumper et al. (2021) Suppl. Alg. 8 MSAColumnAttention

参数:
  • num_head (int) - 头的数量。

  • key_dim (int) - 输入的维度。

  • gating (bool) - 判断attention是否经过gating的指示器。

  • msa_act_dim (int) - msa_act的维度。msa_act为AlphaFold模型中MSA检索后所使用的中间变量。

  • batch_size (int) - MSAColumnAttention中参数的batch size, 默认值:”None”。

  • slice_num (int) - 为了减少内存需要进行切分的数量, 默认值:0。

输入:
  • msa_act (Tensor) - msa_act,AlphaFold模型中MSA检索后所使用的中间变量, \([N_{seqs}, N_{res}, C_m]\)

  • msa_mask (Tensor) - MSAColumnAttention矩阵的mask, \([N_{seqs}, N_{res}]\)

  • index (Tensor) - 在循环中的索引,只会在有控制流的时候使用, 标量, 默认值:”None”。

输出:

Tensor。MSAColumnAttention层的输出msa_act,shape为 \([N_{seqs}, N_{res}, C_m]\)

支持平台:

Ascend GPU

样例:

>>> import numpy as np
>>> from mindsponge.cell import MSAColumnAttention
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor
>>> model = MSAColumnAttention(num_head=8, key_dim=256, gating=True,
...                         msa_act_dim=256, batch_size=1, slice_num=0)
>>> msa_act = Tensor(np.ones((512, 256, 256)), mstype.float32)
>>> msa_mask = Tensor(np.ones((512, 256)), mstype.float32)
>>> index = Tensor(0, mstype.int32)
>>> attn_out = model(msa_act, msa_mask, index)
>>> print(attn_out.shape)
(512, 256, 256)