mindscience.e3nn.nn.Gate

class mindscience.e3nn.nn.Gate(irreps_scalars, acts, irreps_gates, act_gates, irreps_gated, dtype=mindspore.float32, ncon_dtype=mindspore.float32)

Gate activation function.

The input tensor is conceptually split into three disjoint subsets:

  1. Scalars for activation (irreps_scalars): These scalars are transformed element-wise by the corresponding activation functions in acts, without affecting any other part.

  2. Scalars for gating (irreps_gates): These scalars are transformed element-wise by the corresponding activation functions in act_gates, and then used as gates to modulate the third subset.

  3. Gated irreps (irreps_gated): These irreps (of any angular momentum) are multiplied channel-wise by the activated gate scalars from item 2, so every component of a given irrep channel is scaled by the same gate value.

Mathematically, the operation is expressed as

\[\left( \bigoplus_i \phi_i(x_i) \right) \oplus \left( \bigoplus_j \phi_j(g_j)\, y_j \right),\]

where

  • \(x_i\) and \(\phi_i\) correspond to the irreps_scalars and acts,

  • \(g_j\), \(\phi_j\), and \(y_j\) correspond to the irreps_gates, act_gates, and irreps_gated, respectively.

The output irreps are the concatenation of the transformed scalars and the gated irreps, preserving the overall equivariance properties.
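The split-activate-gate procedure described above can be sketched in plain Python. This is an illustration only, not the library's implementation: gate_forward and its arguments are hypothetical names, the gated irreps are described by (multiplicity, l) pairs, and a single activation is assumed for all scalars and a single one for all gates.

```python
import math

def gate_forward(x, n_scalars, act, gated_blocks, act_gate):
    """Hypothetical helper: apply a gate to a flat feature vector.

    x            : list of floats laid out as [scalars | gates | gated irreps]
    n_scalars    : number of scalar channels transformed by `act`
    gated_blocks : list of (mul, l) pairs describing the gated irreps
    act_gate     : activation applied to the gate scalars
    """
    # One gate scalar is needed per gated irrep channel.
    n_gates = sum(mul for mul, _ in gated_blocks)
    scalars = x[:n_scalars]
    gates = x[n_scalars:n_scalars + n_gates]
    gated = x[n_scalars + n_gates:]

    # Subset 1: scalars are activated element-wise.
    out = [act(s) for s in scalars]
    # Subset 2: gate scalars are activated, then consumed (not emitted).
    gate_values = [act_gate(g) for g in gates]

    # Subset 3: each irrep channel (2l + 1 components) is scaled by one gate.
    offset, channel = 0, 0
    for mul, l in gated_blocks:
        width = 2 * l + 1
        for _ in range(mul):
            g = gate_values[channel]
            out.extend(g * v for v in gated[offset:offset + width])
            offset += width
            channel += 1
    return out

# 2 scalars, 3 gates, gated irreps 2x(l=1) + 1x(l=2): 16 inputs in total.
x = [0.0, 1.0, -2.0, 0.5, 3.0] + [1.0] * 11
out = gate_forward(x, 2, math.tanh, [(2, 1), (1, 2)], abs)
print(len(out))  # 13: the gate scalars do not appear in the output
```

Because each gate value is a scalar that multiplies a whole irrep channel uniformly, the scaling commutes with rotations of that channel, which is why the construction keeps the output equivariant.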

Parameters
  • irreps_scalars (Union[str, Irrep, Irreps]) – Scalar irreps to be activated by acts.

  • acts (list[Func]) – Activation functions, one for each part of irreps_scalars. If the list length does not match irreps_scalars, it is clipped or padded with identity functions to match.

  • irreps_gates (Union[str, Irrep, Irreps]) – Scalar irreps to be activated by act_gates and used as gates for irreps_gated.

  • act_gates (list[Func]) – Activation functions, one for each part of irreps_gates. If the list length does not match irreps_gates, it is clipped or padded with identity functions to match.

  • irreps_gated (Union[str, Irrep, Irreps]) – Irreps to be gated.

  • dtype (mindspore.dtype, optional) – Input tensor dtype. Default: mindspore.float32.

  • ncon_dtype (mindspore.dtype, optional) – Dtype for ncon computation. Default: mindspore.float32.
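The padding and clipping of acts and act_gates can be pictured with the sketch below. It is an assumption about the behaviour described in the parameter list, not code taken from the library; fit_activations and identity are hypothetical names.

```python
def identity(x):
    """Activation that leaves its input unchanged."""
    return x

def fit_activations(acts, n_parts):
    """Hypothetical helper: return exactly `n_parts` activation functions.

    Extra entries are dropped (clipping); missing entries are filled
    with the identity function (padding).
    """
    fitted = list(acts)[:n_parts]
    fitted += [identity] * (n_parts - len(fitted))
    return fitted

# One activation supplied for three parts: two identities are appended.
padded = fit_activations([abs], 3)
# Three activations supplied for two parts: the last one is dropped.
clipped = fit_activations([abs, abs, abs], 2)
```

Under this reading, passing a short list leaves the remaining parts untransformed rather than raising an error, so it is worth checking the list length explicitly when every part is meant to be activated.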

Inputs:
  • input (Tensor) - Tensor of shape \((..., irreps\_in.dim)\), where irreps_in is the concatenation of irreps_scalars, irreps_gates, and irreps_gated.

Outputs:
  • output (Tensor) - Tensor of shape \((..., irreps\_out.dim)\), where irreps_out is the concatenation of irreps_scalars and irreps_gated.
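The last-axis sizes of the input and output follow from the fact that an irrep of degree l has 2l + 1 components. The sketch below works this out for the irreps '2x0e', '1x0o+2x0e', and '2x1o+1x2e'; irreps_dim and the (multiplicity, l) tuples are hypothetical stand-ins for the library's Irreps objects, and parity is left out because it does not affect the dimension.

```python
def irreps_dim(blocks):
    """Total dimension of a list of (multiplicity, l) pairs."""
    return sum(mul * (2 * l + 1) for mul, l in blocks)

scalars = [(2, 0)]          # 2x0e
gates = [(1, 0), (2, 0)]    # 1x0o+2x0e
gated = [(2, 1), (1, 2)]    # 2x1o+1x2e

# The input concatenates all three subsets; the output drops the gates.
dim_in = irreps_dim(scalars + gates + gated)
dim_out = irreps_dim(scalars + gated)
print(dim_in, dim_out)  # 16 13
```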

Raises
  • ValueError – If irreps_scalars or irreps_gates contains a non-scalar irrep.

  • ValueError – If the total multiplicity of irreps_gates does not match the total multiplicity of irreps_gated.
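The two conditions above can be checked ahead of time with a sketch like the following. It mirrors the documented conditions only: check_gate_irreps is a hypothetical helper working on (multiplicity, l) pairs, and the error messages are illustrative rather than the library's.

```python
def check_gate_irreps(scalars, gates, gated):
    """Hypothetical helper: raise ValueError if the irreps cannot form a gate."""
    # Both scalar subsets may contain only l = 0 irreps.
    for name, blocks in (("irreps_scalars", scalars), ("irreps_gates", gates)):
        for mul, l in blocks:
            if l != 0:
                raise ValueError(f"{name} contains a non-scalar irrep (l={l})")

    # One gate scalar is required per gated irrep channel.
    n_gates = sum(mul for mul, _ in gates)
    n_gated = sum(mul for mul, _ in gated)
    if n_gates != n_gated:
        raise ValueError(
            f"{n_gates} gate scalars cannot gate {n_gated} irrep channels"
        )

# 3 gate scalars for 2 + 1 gated channels: passes silently.
check_gate_irreps([(2, 0)], [(1, 0), (2, 0)], [(2, 1), (1, 2)])
```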

Examples

>>> from mindspore import ops
>>> from mindscience.e3nn.nn import Gate
>>> Gate('2x0e', [ops.tanh], '1x0o+2x0e', [ops.abs], '2x1o+1x2e')
Gate (2x0e+1x0o+2x0e+2x1o+1x2e -> 2x0e+2x1o+1x2e)