mindscience.e3nn.nn.Gate
- class mindscience.e3nn.nn.Gate(irreps_scalars, acts, irreps_gates, act_gates, irreps_gated, dtype=mindspore.float32, ncon_dtype=mindspore.float32)
Gate activation function.
The input tensor is split into three disjoint parts, stored one after another in this order:
Scalars for activation (irreps_scalars): transformed element-wise by the corresponding activation functions in acts, independently of the other two parts.
Gate scalars (irreps_gates): transformed element-wise by the corresponding activation functions in act_gates; the results serve as gates for the third part.
Gated irreps (irreps_gated): irreps of any angular momentum l; each copy (all 2l + 1 components) is multiplied by its own activated gate scalar from the second part.
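The three-way split can be sketched in plain Python on a flat list of floats. This is a minimal illustration, not the MindScience implementation: the helpers `parse_irreps` and `gate_forward` are hypothetical, irreps are given as strings such as `'2x1o'`, and `acts`/`act_gates` are assumed to already hold exactly one function per irreps part.

```python
import math


def parse_irreps(s):
    """Parse an irreps string such as '2x0e+1x1o' into (mul, l, parity) tuples."""
    parts = []
    for term in filter(None, (t.strip() for t in s.split("+"))):
        mul, ir = term.split("x") if "x" in term else ("1", term)
        parts.append((int(mul), int(ir[:-1]), ir[-1]))
    return parts


def gate_forward(x, irreps_scalars, acts, irreps_gates, act_gates, irreps_gated):
    """Hypothetical gate on a flat list x laid out as [scalars | gates | gated]."""
    pos, out = 0, []
    # 1) Free scalars: element-wise activation, one function per irreps part.
    for (mul, _, _), act in zip(parse_irreps(irreps_scalars), acts):
        out.extend(act(v) for v in x[pos:pos + mul])
        pos += mul
    # 2) Gate scalars: activate them, but do not emit them in the output.
    gates = []
    for (mul, _, _), act in zip(parse_irreps(irreps_gates), act_gates):
        gates.extend(act(v) for v in x[pos:pos + mul])
        pos += mul
    # 3) Gated irreps: every copy (2l + 1 components) is scaled by its own gate.
    k = 0
    for mul, l, _ in parse_irreps(irreps_gated):
        for _ in range(mul):
            dim = 2 * l + 1
            out.extend(gates[k] * v for v in x[pos:pos + dim])
            pos += dim
            k += 1
    if pos != len(x) or k != len(gates):
        raise ValueError("input length or gate count does not match the irreps")
    return out


# Shapes from the example at the end of this page: 16 inputs -> 13 outputs.
x = [0.1 * i - 0.8 for i in range(16)]
y = gate_forward(x, "2x0e", [math.tanh], "1x0o+2x0e", [abs, lambda v: v], "2x1o+1x2e")
print(len(y))  # 13
```

Python's built-in `abs` stands in for `ops.abs` here, and the identity lambda plays the role of the padded activation for the `2x0e` gates.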
Mathematically, the operation is expressed as
\[\left( \bigoplus_i \phi_i(x_i) \right) \oplus \left( \bigoplus_j \phi_j(g_j)\, y_j \right),\]
where
\(x_i\) and \(\phi_i\) correspond to irreps_scalars and acts, respectively;
\(g_j\), \(\phi_j\), and \(y_j\) correspond to irreps_gates, act_gates, and irreps_gated, respectively.
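The output is the concatenation of the activated scalars and the gated irreps; the activated gate scalars are consumed and do not appear in the output. Because each gate is a scalar, scaling an irrep by it commutes with rotations, so the layer preserves equivariance.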
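A small numeric check of why gating keeps equivariance, assuming a plain 3x3 rotation acting on a single `1o` vector. The helper names and the sigmoid gate are illustrative choices, and the l = 1 basis ordering used internally by MindScience is not modeled. Scaling by an invariant scalar commutes with the rotation, whereas applying `tanh` to each vector component does not.

```python
import math


def rot_z(theta):
    """3x3 rotation matrix about the z axis."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def matvec(m, v):
    """Multiply a 3x3 matrix by a 3-vector."""
    return [sum(m[i][j] * v[j] for j in range(3)) for i in range(3)]


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def gated(g, v):
    """Gate the vector v by the activated scalar sigmoid(g)."""
    return [sigmoid(g) * c for c in v]


def componentwise_tanh(g, v):
    """A non-equivariant alternative: tanh on each component (g is ignored)."""
    return [math.tanh(c) for c in v]


def is_equivariant(f, g, v, rot, tol=1e-9):
    """Compare f(g, R v) with R f(g, v); the scalar g is invariant under R."""
    lhs = f(g, matvec(rot, v))
    rhs = matvec(rot, f(g, v))
    return all(abs(a - b) < tol for a, b in zip(lhs, rhs))


R = rot_z(math.pi / 4)
v = [2.0, 1.0, 0.0]
print(is_equivariant(gated, 0.3, v, R))               # True
print(is_equivariant(componentwise_tanh, 0.3, v, R))  # False
```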
- Parameters
irreps_scalars (Union[str, Irrep, Irreps]) – Scalar irreps to be activated by acts.
acts (list[Func]) – Activation functions, one for each part of irreps_scalars. If the list is shorter than the number of parts, it is padded with identity functions; if it is longer, it is clipped.
irreps_gates (Union[str, Irrep, Irreps]) – Scalar irreps to be activated by act_gates and used as gates for irreps_gated.
act_gates (list[Func]) – Activation functions, one for each part of irreps_gates. If the list is shorter than the number of parts, it is padded with identity functions; if it is longer, it is clipped.
irreps_gated (Union[str, Irrep, Irreps]) – Irreps to be gated.
dtype (mindspore.dtype, optional) – Input tensor dtype. Default: mindspore.float32.
ncon_dtype (mindspore.dtype, optional) – Dtype for the ncon computation. Default: mindspore.float32.
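The padding/clipping rule for acts and act_gates can be sketched as below. This encodes an assumption about the behavior described above (identity functions appended at the end, surplus functions dropped from the end); `count_parts` and `fit_activations` are hypothetical helpers, not part of MindScience.

```python
def count_parts(irreps):
    """Number of '+'-separated parts in an irreps string, e.g. '1x0o+2x0e' -> 2."""
    return len([t for t in irreps.split("+") if t.strip()])


def fit_activations(fns, num_parts):
    """Truncate fns, or pad it with identity functions, to exactly num_parts entries."""
    fns = list(fns)[:num_parts]
    return fns + [lambda v: v] * (num_parts - len(fns))


# In the example at the end of this page, act_gates=[ops.abs] has one entry but
# irreps_gates='1x0o+2x0e' has two parts, so the 2x0e gates would use the identity.
acts = fit_activations([abs], count_parts("1x0o+2x0e"))
print(len(acts), acts[0](-3.0), acts[1](-3.0))  # 2 3.0 -3.0
```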
- Inputs:
input (Tensor) - Tensor of shape \((..., irreps\_in.dim)\), where irreps_in is irreps_scalars + irreps_gates + irreps_gated.
- Outputs:
output (Tensor) - Tensor of shape \((..., irreps\_out.dim)\), where irreps_out consists of the activated scalars followed by the gated irreps.
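Both feature dimensions follow from counting components: each irrep of degree l contributes 2l + 1 of them per copy, so only multiplicities and l values matter. A minimal sketch with a hypothetical `irreps_dim` helper, checked against the shapes of the example at the end of this page:

```python
def irreps_dim(irreps):
    """Sum of mul * (2l + 1) over all parts of an irreps string like '2x1o+1x2e'."""
    total = 0
    for term in irreps.split("+"):
        term = term.strip()
        mul, ir = term.split("x") if "x" in term else ("1", term)
        total += int(mul) * (2 * int(ir[:-1]) + 1)
    return total


scalars, gates, gated = "2x0e", "1x0o+2x0e", "2x1o+1x2e"
dim_in = irreps_dim(scalars) + irreps_dim(gates) + irreps_dim(gated)
dim_out = irreps_dim(scalars) + irreps_dim(gated)  # gate scalars are consumed
print(dim_in, dim_out)  # 16 13
```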
- Raises
ValueError – If irreps_scalars or irreps_gates contains a non-scalar irrep (l > 0).
ValueError – If the number of scalars in irreps_gates (its total multiplicity) does not match the number of irreps in irreps_gated (its total multiplicity).
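Both checks can be sketched as follows. `check_gate_irreps` is a hypothetical helper that mirrors the two conditions listed above; it is not the library's own validation code, and its error messages are illustrative.

```python
def parse_irreps(s):
    """Parse '2x0e+1x1o' into (mul, l, parity) tuples; an empty string gives []."""
    parts = []
    for term in filter(None, (t.strip() for t in s.split("+"))):
        mul, ir = term.split("x") if "x" in term else ("1", term)
        parts.append((int(mul), int(ir[:-1]), ir[-1]))
    return parts


def check_gate_irreps(irreps_scalars, irreps_gates, irreps_gated):
    """Raise ValueError under the two conditions listed in the Raises section."""
    for name, irreps in (("irreps_scalars", irreps_scalars), ("irreps_gates", irreps_gates)):
        if any(l != 0 for _, l, _ in parse_irreps(irreps)):
            raise ValueError(f"{name} must contain only scalar (l = 0) irreps")
    num_gates = sum(mul for mul, _, _ in parse_irreps(irreps_gates))
    num_gated = sum(mul for mul, _, _ in parse_irreps(irreps_gated))
    if num_gates != num_gated:
        raise ValueError(f"{num_gates} gate scalars for {num_gated} gated irreps")


check_gate_irreps("2x0e", "1x0o+2x0e", "2x1o+1x2e")  # passes: 3 gates, 3 gated irreps
try:
    check_gate_irreps("2x0e", "2x0e", "2x1o+1x2e")
except ValueError as err:
    print(err)  # 2 gate scalars for 3 gated irreps
```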
Examples
>>> from mindspore import ops
>>> from mindscience.e3nn.nn import Gate
>>> Gate('2x0e', [ops.tanh], '1x0o+2x0e', [ops.abs], '2x1o+1x2e')
Gate (2x0e+1x0o+2x0e+2x1o+1x2e -> 2x0e+2x1o+1x2e)