mindchemistry.e3.nn.Gate

class mindchemistry.e3.nn.Gate(irreps_scalars, acts, irreps_gates, act_gates, irreps_gated, dtype=float32, ncon_dtype=float32)

Gate activation function. The input consists of three parts. The first part, irreps_scalars, contains scalars that are only passed through the activation functions acts. The second part, irreps_gates, contains scalars that are passed through the activation functions act_gates and then multiplied onto the third part, irreps_gated.

\[\left(\bigoplus_i \phi_i(x_i) \right) \oplus \left(\bigoplus_j \phi_j(g_j) y_j \right)\]

where \(x_i\) and \(\phi_i\) come from irreps_scalars and acts, and \(g_j\), \(\phi_j\), and \(y_j\) come from irreps_gates, act_gates, and irreps_gated, respectively.
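To make the formula concrete, here is a minimal, framework-free sketch of the gating computation in plain Python. It is not the mindchemistry implementation. Irreps are simplified to hypothetical `(mul, l)` pairs (parity is ignored), the activation lists are assumed to already have one function per irrep part, and the k-th gate scalar is assumed to multiply the k-th irrep copy in irreps_gated.

```python
import math
from typing import Callable, List, Sequence, Tuple

# Hypothetical simplified irreps: (multiplicity, l) pairs; parity is ignored.
Irreps = List[Tuple[int, int]]


def irreps_dim(irreps: Irreps) -> int:
    # Each copy of a degree-l irrep occupies 2l + 1 entries.
    return sum(mul * (2 * l + 1) for mul, l in irreps)


def apply_per_part(values: Sequence[float], irreps: Irreps,
                   funcs: List[Callable[[float], float]]) -> List[float]:
    # Apply funcs[i] elementwise to the scalars of the i-th part.
    out, pos = [], 0
    for (mul, l), f in zip(irreps, funcs):
        if l != 0:
            raise ValueError("activations only apply to scalar (l = 0) irreps")
        out.extend(f(v) for v in values[pos:pos + mul])
        pos += mul
    return out


def gate_sketch(x: Sequence[float], scalars: Irreps, acts,
                gates: Irreps, act_gates, gated: Irreps) -> List[float]:
    n_s, n_g = irreps_dim(scalars), irreps_dim(gates)
    if len(x) != n_s + n_g + irreps_dim(gated):
        raise ValueError("input length does not match the irreps")
    xs, gs, ys = x[:n_s], x[n_s:n_s + n_g], x[n_s + n_g:]

    scalar_out = apply_per_part(xs, scalars, acts)      # phi_i(x_i)
    gate_vals = apply_per_part(gs, gates, act_gates)    # phi_j(g_j)

    gated_out, pos, k = [], 0, 0
    for mul, l in gated:
        d = 2 * l + 1
        for _ in range(mul):
            # Scale one whole irrep copy y_j by its gate value.
            gated_out.extend(gate_vals[k] * y for y in ys[pos:pos + d])
            pos += d
            k += 1
    return scalar_out + gated_out


# One tanh-activated scalar, two abs-activated gates, gated irreps 1x(l=1) + 1x(l=0).
x = [0.0, -2.0, 3.0, 1.0, 2.0, 3.0, 4.0]
print(gate_sketch(x, [(1, 0)], [math.tanh], [(2, 0)], [abs, abs], [(1, 1), (1, 0)]))
# [0.0, 2.0, 4.0, 6.0, 12.0]
```

Each gate scales an entire irrep copy by a single scalar, which is why the output of the gated part keeps the same irreps as its input.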

Parameters
  • irreps_scalars (Union[str, Irrep, Irreps]) – the input scalar irreps that will be passed through the activation functions acts.

  • acts (List[Func]) – a list of activation functions, one for each part of irreps_scalars. If acts is longer than irreps_scalars, it is clipped; if it is shorter, it is padded with identity functions.

  • irreps_gates (Union[str, Irrep, Irreps]) – the input scalar irreps that will be passed through the activation functions act_gates and multiplied by irreps_gated.

  • act_gates (List[Func]) – a list of activation functions, one for each part of irreps_gates. If act_gates is longer than irreps_gates, it is clipped; if it is shorter, it is padded with identity functions.

  • irreps_gated (Union[str, Irrep, Irreps]) – the input irreps that will be gated.

  • dtype (mindspore.dtype) – The data type of the input tensor. Default: mindspore.float32.

  • ncon_dtype (mindspore.dtype) – The data type of the input tensors of the ncon computation module. Default: mindspore.float32.
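The clipping and padding rule for acts and act_gates can be sketched in plain Python as follows. This is an illustration only: `match_acts`, `count_parts`, and `identity` are hypothetical helpers, and a "part" is assumed to be one `mul x irrep` term of the irreps string.

```python
from typing import Callable, List


def identity(v):
    # Stand-in activation used for padding.
    return v


def count_parts(irreps: str) -> int:
    # Number of 'mul x irrep' terms, e.g. '1x0o+2x0e' -> 2.
    return len([t for t in irreps.split("+") if t.strip()])


def match_acts(acts: List[Callable], n_parts: int) -> List[Callable]:
    # Clip surplus activations, then pad any shortfall with identity.
    matched = list(acts[:n_parts])
    matched.extend([identity] * (n_parts - len(matched)))
    return matched


# As in the Examples section: one function for the two parts of '1x0o+2x0e'.
fns = match_acts([abs], count_parts("1x0o+2x0e"))
print([f(-3) for f in fns])  # [3, -3]
```

Under this reading, the second part of '1x0o+2x0e' in the Examples section is left unactivated.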

Inputs:
  • input (Tensor) - The shape of Tensor is \((..., irreps\_in.dim)\), where irreps_in is the concatenation irreps_scalars + irreps_gates + irreps_gated.

Outputs:
  • output (Tensor) - The shape of Tensor is \((..., irreps\_out.dim)\), where irreps_out consists of the activated scalars followed by the gated irreps.
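As a worked example of these shapes, the sketch below computes the last dimension of the input and output for the irreps used in the Examples section. The helper names are hypothetical, and the sketch assumes, as the printed representation of that example indicates, that irreps_in is irreps_scalars + irreps_gates + irreps_gated and irreps_out is the activated scalars followed by irreps_gated. Each copy of a degree-l irrep contributes 2l + 1 components.

```python
from typing import List, Tuple


def parse_irreps(s: str) -> List[Tuple[int, int, str]]:
    # Parse e.g. '2x1o+1x2e' into (mul, l, parity) triples; a bare '0e' means mul 1.
    out = []
    for term in s.split("+"):
        term = term.strip()
        if not term:
            continue
        mul, ir = term.split("x") if "x" in term else ("1", term)
        out.append((int(mul), int(ir[:-1]), ir[-1]))
    return out


def dim(s: str) -> int:
    return sum(mul * (2 * l + 1) for mul, l, _ in parse_irreps(s))


scalars, gates, gated = "2x0e", "1x0o+2x0e", "2x1o+1x2e"
dim_in = dim(scalars) + dim(gates) + dim(gated)   # 2 + 3 + 11 = 16
dim_out = dim(scalars) + dim(gated)               # 2 + 11 = 13
print(dim_in, dim_out)  # 16 13
```

The gate scalars are consumed by the multiplication, so the output is shorter than the input by exactly dim(irreps_gates).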

Raises
  • ValueError – If irreps_scalars or irreps_gates contains a non-scalar irrep.

  • ValueError – If the total multiplicity of irreps_gates does not match the total multiplicity of irreps_gated.
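The two conditions listed above can be sketched in plain Python. This is a hypothetical re-implementation for illustration (`validate_gate_irreps` is not a mindchemistry function); it parses simple irreps strings such as '2x1o+1x2e' and raises ValueError under the same two conditions.

```python
from typing import List, Tuple


def parse_irreps(s: str) -> List[Tuple[int, int]]:
    # Parse e.g. '2x1o+1x2e' into (mul, l) pairs; parity is not needed here.
    out = []
    for term in s.split("+"):
        term = term.strip()
        if not term:
            continue
        mul, ir = term.split("x") if "x" in term else ("1", term)
        out.append((int(mul), int(ir[:-1])))
    return out


def validate_gate_irreps(irreps_scalars: str, irreps_gates: str, irreps_gated: str) -> None:
    scalars, gates, gated = map(parse_irreps, (irreps_scalars, irreps_gates, irreps_gated))
    if any(l != 0 for _, l in scalars + gates):
        raise ValueError("irreps_scalars and irreps_gates must contain only scalars")
    n_gates = sum(mul for mul, _ in gates)
    n_gated = sum(mul for mul, _ in gated)
    if n_gates != n_gated:
        # One gate scalar is needed per gated irrep copy.
        raise ValueError(f"{n_gates} gates for {n_gated} gated irreps")


validate_gate_irreps("2x0e", "1x0o+2x0e", "2x1o+1x2e")  # passes: 3 gates, 3 gated copies
try:
    validate_gate_irreps("2x0e", "2x0e", "2x1o+1x2e")
except ValueError as err:
    print(err)  # 2 gates for 3 gated irreps
```

Multiplicity, not dimension, is what must match: '2x1o+1x2e' has dimension 11 but needs only 3 gates.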

Supported Platforms:

Ascend

Examples

>>> from mindspore import ops
>>> from mindchemistry.e3.nn import Gate
>>> Gate('2x0e', [ops.tanh], '1x0o+2x0e', [ops.abs], '2x1o+1x2e')
Gate (2x0e+1x0o+2x0e+2x1o+1x2e -> 2x0e+2x1o+1x2e)