mindchemistry.e3.nn.Gate
- class mindchemistry.e3.nn.Gate(irreps_scalars, acts, irreps_gates, act_gates, irreps_gated, dtype=float32, ncon_dtype=float32)
Gate activation function. The input contains three parts. The first part, irreps_scalars, consists of scalars that are only passed through the activation functions acts. The second part, irreps_gates, consists of scalars that are passed through the activation functions act_gates and then multiplied with the third part, irreps_gated, which holds the irreps to be gated.
\[\left(\bigoplus_i \phi_i(x_i) \right) \oplus \left(\bigoplus_j \phi_j(g_j) y_j \right)\]
where \(x_i\) and \(\phi_i\) are from irreps_scalars and acts, and \(g_j\), \(\phi_j\), and \(y_j\) are from irreps_gates, act_gates, and irreps_gated.
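The formula can be traced on a single flattened feature vector with the minimal pure-Python sketch below. It is only an illustration, not the MindChemistry implementation: the function gate_forward is hypothetical, irreps are written as (multiplicity, l) pairs with parity ignored, acts and act_gates are assumed to already hold one function per irreps part, and each copy of a degree-l gated irrep is assumed to occupy 2l+1 consecutive entries.

```python
import math
from typing import Callable, List, Sequence, Tuple

Irreps = List[Tuple[int, int]]  # (multiplicity, l); parity is ignored in this sketch
Act = Callable[[float], float]


def gate_forward(x: Sequence[float], irreps_scalars: Irreps, acts: List[Act],
                 irreps_gates: Irreps, act_gates: List[Act],
                 irreps_gated: Irreps) -> List[float]:
    """Hypothetical sketch of (+)_i phi_i(x_i) (+) (+)_j phi_j(g_j) y_j on one vector."""
    pos = 0
    out: List[float] = []
    # First part: every scalar channel goes through the activation of its part.
    for (mul, _), act in zip(irreps_scalars, acts):
        out.extend(act(v) for v in x[pos:pos + mul])
        pos += mul
    # Second part: activated gate scalars, one per copy of a gated irrep.
    gates: List[float] = []
    for (mul, _), act in zip(irreps_gates, act_gates):
        gates.extend(act(v) for v in x[pos:pos + mul])
        pos += mul
    # Third part: each copy of a degree-l irrep (2l+1 entries) is scaled by one gate.
    k = 0
    for mul, degree in irreps_gated:
        d = 2 * degree + 1
        for _ in range(mul):
            out.extend(gates[k] * v for v in x[pos:pos + d])
            pos += d
            k += 1
    return out


def identity(v: float) -> float:
    return v


# Layout mirroring Gate('2x0e', [tanh], '1x0o+2x0e', [abs], '2x1o+1x2e'),
# with act_gates filled to [abs, identity] for the two parts of irreps_gates.
x = [0.5, -1.0,                     # 2x0e scalars
     2.0, -3.0, 0.0,                # 1x0o+2x0e gates
     1.0, 2.0, 3.0, 4.0, 5.0, 6.0,  # 2x1o gated (two copies of 3 entries)
     1.0, 1.0, 1.0, 1.0, 1.0]       # 1x2e gated (one copy of 5 entries)
y = gate_forward(x, [(2, 0)], [math.tanh], [(1, 0), (2, 0)], [abs, identity],
                 [(2, 1), (1, 2)])
```

The 16 input entries shrink to 13 output entries: the 3 gate scalars are consumed, while the 2 activated scalars and the 11 gated components keep their sizes.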
- Parameters
irreps_scalars (Union[str, Irrep, Irreps]) – the input scalar irreps that will be passed through the activation functions acts.
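acts (List[Func]) – a list of activation functions, one for each part of irreps_scalars. If acts has more entries than irreps_scalars has parts, the extra entries are dropped; if it has fewer, it is filled with identity functions so that its length matches the number of parts in irreps_scalars.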
irreps_gates (Union[str, Irrep, Irreps]) – the input scalar irreps that will be passed through the activation functions act_gates and multiplied by irreps_gated.
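act_gates (List[Func]) – a list of activation functions, one for each part of irreps_gates. If act_gates has more entries than irreps_gates has parts, the extra entries are dropped; if it has fewer, it is filled with identity functions so that its length matches the number of parts in irreps_gates.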
irreps_gated (Union[str, Irrep, Irreps]) – the input irreps that will be gated.
dtype (mindspore.dtype) – The dtype of the input tensor. Default: mindspore.float32.
ncon_dtype (mindspore.dtype) – The dtype of the input tensors of the ncon computation module. Default: mindspore.float32.
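The clip-or-fill rule for acts and act_gates can be sketched as follows. The helper match_acts is hypothetical, and the sketch assumes that identity functions are appended at the end of the list; the actual MindChemistry code may differ in detail.

```python
from typing import Callable, List

Act = Callable[[float], float]


def identity(v: float) -> float:
    return v


def match_acts(acts: List[Act], num_parts: int) -> List[Act]:
    """Hypothetical helper: clip `acts` to `num_parts` entries, then fill with identity."""
    clipped = list(acts[:num_parts])
    # Assumption: missing activations are appended at the end of the list.
    return clipped + [identity] * (num_parts - len(clipped))


# '1x0o+2x0e' has two parts, so [abs] is filled to [abs, identity].
padded = match_acts([abs], 2)
# '2x0e' has one part, so a two-element list is clipped to its first entry.
clipped = match_acts([abs, identity], 1)
```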
- Inputs:
input (Tensor) - The input tensor of shape \((..., irreps\_in.dim)\), where irreps_in is the concatenation irreps_scalars + irreps_gates + irreps_gated.
- Outputs:
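output (Tensor) - The output tensor of shape \((..., irreps\_out.dim)\), where irreps_out consists of the activated irreps_scalars followed by the gated irreps (2x0e+2x1o+1x2e in the example below).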
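To see how irreps_in.dim and irreps_out.dim relate to the three irreps arguments, this sketch computes them for the irreps used in the example, using the fact that an irrep of degree l has dimension 2l+1. The functions parse_irreps and irreps_dim are hypothetical stand-ins that only understand simple strings such as '2x1o+1x2e'; they are not the mindchemistry Irreps class.

```python
from typing import List, Tuple


def parse_irreps(s: str) -> List[Tuple[int, int, str]]:
    """Hypothetical parser: '2x1o+1x2e' -> [(2, 1, 'o'), (1, 2, 'e')] as (mul, l, parity)."""
    parts = []
    for term in s.replace(" ", "").split("+"):
        mul_str, _, ir = term.rpartition("x")  # '2x1o' -> ('2', 'x', '1o')
        mul = int(mul_str) if mul_str else 1   # a bare '0e' means multiplicity 1
        parts.append((mul, int(ir[:-1]), ir[-1]))
    return parts


def irreps_dim(s: str) -> int:
    """Each copy of a degree-l irrep contributes 2l + 1 components."""
    return sum(mul * (2 * degree + 1) for mul, degree, _ in parse_irreps(s))


scalars, gates, gated = "2x0e", "1x0o+2x0e", "2x1o+1x2e"
# Input: scalars, gates, and gated irreps concatenated along the last axis.
in_dim = irreps_dim(scalars) + irreps_dim(gates) + irreps_dim(gated)
# Output: gate scalars are consumed; activated scalars and gated irreps keep their sizes.
out_dim = irreps_dim(scalars) + irreps_dim(gated)
print(in_dim, out_dim)  # 16 13
```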
- Raises
ValueError – If irreps_scalars or irreps_gates contains a non-scalar irrep.
ValueError – If the total multiplicity of irreps_gates does not match the total multiplicity of irreps_gated.
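The two ValueError conditions can be expressed as the checks in this sketch. The function check_gate_irreps and its inline parser are hypothetical; they mirror the documented conditions on irreps strings and are not the library's own validation code.

```python
from typing import List, Tuple


def parse_irreps(s: str) -> List[Tuple[int, int]]:
    """Hypothetical parser: '1x0o+2x0e' -> [(1, 0), (2, 0)] as (multiplicity, l)."""
    parts = []
    for term in s.replace(" ", "").split("+"):
        mul_str, _, ir = term.rpartition("x")
        parts.append((int(mul_str) if mul_str else 1, int(ir[:-1])))
    return parts


def check_gate_irreps(irreps_scalars: str, irreps_gates: str, irreps_gated: str) -> None:
    """Hypothetical mirror of the documented ValueError conditions."""
    # Both scalar inputs may contain only l = 0 irreps.
    for name, s in (("irreps_scalars", irreps_scalars), ("irreps_gates", irreps_gates)):
        if any(degree != 0 for _, degree in parse_irreps(s)):
            raise ValueError(f"{name} contains a non-scalar irrep: {s}")
    # One gate scalar is needed per copy of a gated irrep.
    num_gates = sum(mul for mul, _ in parse_irreps(irreps_gates))
    num_gated = sum(mul for mul, _ in parse_irreps(irreps_gated))
    if num_gates != num_gated:
        raise ValueError(f"{num_gates} gates do not match {num_gated} gated irreps")


check_gate_irreps("2x0e", "1x0o+2x0e", "2x1o+1x2e")  # passes: 3 gates for 3 gated irreps
```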
- Supported Platforms:
Ascend
Examples
>>> from mindspore import ops
>>> from mindchemistry.e3.nn import Gate
>>> Gate('2x0e', [ops.tanh], '1x0o+2x0e', [ops.abs], '2x1o+1x2e')
Gate (2x0e+1x0o+2x0e+2x1o+1x2e -> 2x0e+2x1o+1x2e)