mindsponge.cell.TriangleMultiplication

class mindsponge.cell.TriangleMultiplication(num_intermediate_channel, equation, layer_norm_dim, batch_size=None)

Triangle multiplication layer. For the detailed implementation process, refer to TriangleMultiplication.

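The layer integrates information about each amino acid pair (i, j) through the three edges ij, ik, and jk of the triangles the pair forms with every other residue k: the element-wise product of the ik and jk edge features, summed over k, is added to edge ij. This describes the ‘outgoing’ form; the ‘incoming’ form uses edges ki and kj instead.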
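As a rough illustration of the core update (not MindSponge's implementation), the NumPy sketch below assumes two already-computed edge projections `a` and `b` (hypothetical names, random toy data) and checks that the einsum form of the ‘outgoing’ update matches explicit loops over the third residue k. The projections, gating, and layer norms of the real layer are omitted.

```python
import numpy as np

# Hypothetical edge projections for N_res = 4 residues and 3 channels (toy data).
rng = np.random.default_rng(0)
n_res, c = 4, 3
a = rng.standard_normal((n_res, n_res, c))  # features of edges ik
b = rng.standard_normal((n_res, n_res, c))  # features of edges jk

# Vectorized form: for every pair (i, j) and channel c, sum a[i, k, c] * b[j, k, c] over k.
z = np.einsum("ikc,jkc->ijc", a, b)

# The same update written as explicit loops over the third node k of each triangle.
z_loop = np.zeros((n_res, n_res, c))
for i in range(n_res):
    for j in range(n_res):
        for k in range(n_res):
            z_loop[i, j] += a[i, k] * b[j, k]

print(z.shape)                 # (4, 4, 3)
print(np.allclose(z, z_loop))  # True
```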

Parameters
  • num_intermediate_channel (int) – The number of intermediate channels.

  • equation (str) – The equation, in einsum notation, used in the triangle multiplication layer. It selects the edge-update form: "ikc,jkc->ijc" for ‘outgoing’ edges and "kjc,kic->ijc" for ‘incoming’ edges.

  • layer_norm_dim (int) – The size of the last dimension normalized by the layer norm, i.e. the channel dimension of pair_act.

  • batch_size (int) – The batch size of the parameters in the triangle multiplication layer. Default: None.
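To see how the two values of `equation` differ, the sketch below evaluates both with NumPy's `einsum` on random toy arrays; `outgoing` and `incoming` are hypothetical helper names, not part of MindSponge. It also checks one property that follows from the index patterns: the ‘incoming’ form equals the ‘outgoing’ form applied to the transposed projections with the operands swapped.

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.standard_normal((5, 5, 2))  # toy left projection
b = rng.standard_normal((5, 5, 2))  # toy right projection

def outgoing(a, b):
    # z[i, j] = sum_k a[i, k] * b[j, k]: edges leaving i and j toward k
    return np.einsum("ikc,jkc->ijc", a, b)

def incoming(a, b):
    # z[i, j] = sum_k a[k, j] * b[k, i]: edges arriving at j and i from k
    return np.einsum("kjc,kic->ijc", a, b)

# Transpose the residue axes of both operands, swap them, and reuse the outgoing form.
swapped = outgoing(b.transpose(1, 0, 2), a.transpose(1, 0, 2))
print(np.allclose(incoming(a, b), swapped))  # True
```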

Inputs:
  • pair_act (Tensor) - The pair activations, with shape \((N_{res}, N_{res}, layer\_norm\_dim)\).

  • pair_mask (Tensor) - The mask for pair_act, with shape \((N_{res}, N_{res})\).

  • index (Tensor) - The index of the while loop, used only with while control flow.

Outputs:

Tensor, the float pair_act output of the layer, with shape \((N_{res}, N_{res}, layer\_norm\_dim)\).
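For orientation on the input and output shapes, here is a minimal NumPy sketch of a full triangle multiplicative update, assuming the structure of the AlphaFold2 reference algorithm (layer norm, gated left/right projections, masked einsum, layer norm, output projection, output gate). It is not the MindSponge implementation: `triangle_multiplication_sketch` and the `params` keys are hypothetical, the weights are random, and biases and learnable layer-norm parameters are omitted. It returns the update only; in the AlphaFold2 Evoformer, that update is added to pair_act as a residual.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last (channel) axis; learnable scale/offset omitted for brevity.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def triangle_multiplication_sketch(pair_act, pair_mask, params, equation="ikc,jkc->ijc"):
    """Hypothetical AlphaFold2-style triangle multiplicative update (assumption, not MindSponge)."""
    act = layer_norm(pair_act)                     # (N, N, D)
    mask = pair_mask[..., None]                    # (N, N, 1), broadcast over channels
    # Gated left/right projections to C channels; masked pairs contribute zero to every sum.
    left = mask * (act @ params["left_proj"]) * sigmoid(act @ params["left_gate"])
    right = mask * (act @ params["right_proj"]) * sigmoid(act @ params["right_gate"])
    # Combine the two other edges of each triangle (i, j, k), summing over k.
    tri = np.einsum(equation, left, right)         # (N, N, C)
    # Normalize, project back to D channels, and apply an output gate computed from the input.
    out = layer_norm(tri) @ params["output_proj"]  # (N, N, D)
    return out * sigmoid(act @ params["output_gate"])

# Toy sizes: N_res, layer_norm_dim, num_intermediate_channel.
n_res, d, c = 8, 16, 4
rng = np.random.default_rng(0)
shapes = {"left_proj": (d, c), "left_gate": (d, c), "right_proj": (d, c),
          "right_gate": (d, c), "output_proj": (c, d), "output_gate": (d, d)}
params = {name: 0.1 * rng.standard_normal(s) for name, s in shapes.items()}

pair_act = rng.standard_normal((n_res, n_res, d))
pair_mask = np.ones((n_res, n_res))
update = triangle_multiplication_sketch(pair_act, pair_mask, params)
print(update.shape)  # (8, 8, 16)
```

Passing equation="kjc,kic->ijc" gives the ‘incoming’ variant with the same output shape.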

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> from mindsponge.cell import TriangleMultiplication
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor
>>> model = TriangleMultiplication(num_intermediate_channel=64,
...                                equation="ikc,jkc->ijc", layer_norm_dim=64, batch_size=0)
>>> # pair_act with N_res = 256 and layer_norm_dim = 64
>>> input_0 = Tensor(np.ones((256, 256, 64)), mstype.float32)
>>> # pair_mask with shape (N_res, N_res)
>>> input_1 = Tensor(np.ones((256, 256)), mstype.float32)
>>> out = model(input_0, input_1, index=0)
>>> print(out.shape)
(256, 256, 64)