mindchemistry.e3.nn.Gate

查看源文件
class mindchemistry.e3.nn.Gate(irreps_scalars, acts, irreps_gates, act_gates, irreps_gated, dtype=float32, ncon_dtype=float32)[源代码]

门控激活函数。输入包含三部分:第一部分 irreps_scalars 是只受激活函数 acts 影响的标量;第二部分 irreps_gates 是受激活函数 act_gates 影响并与第三部分相乘的标量。

\[\left(\bigoplus_i \phi_i(x_i) \right) \oplus \left(\bigoplus_j \phi_j(g_j) y_j \right)\]

其中 \(x_i\)\(\phi_i\) 来自 irreps_scalarsacts,而 \(g_j\)\(\phi_j\)\(y_j\) 来自 irreps_gatesact_gatesirreps_gated

参数:
  • irreps_scalars (Union[str, Irrep, Irreps]) - 将通过激活函数 acts 的输入标量不可约表示。

  • acts (List[Func]) - 对 irreps_scalars 的每部分应用的激活函数列表。acts 的长度将被剪切或填充为恒等函数,以匹配 irreps_scalars 的长度。

  • irreps_gates (Union[str, Irrep, Irreps]) - 将通过激活函数 act_gates 并与 irreps_gated 相乘的输入标量不可约表示。

  • act_gates (List[Func]) - 每个 irreps_gates 部分的激活函数列表。 acts 的长度将被剪切或填充为恒等函数,以匹配 irreps_gates 的长度。

  • irreps_gated (Union[str, Irrep, Irreps]) - 将被门控的输入不可约表示。

  • dtype (mindspore.dtype) - 输入张量的类型。默认值:mindspore.float32

  • ncon_dtype (mindspore.dtype) - ncon 计算模块输入张量的类型。默认值:mindspore.float32

输入:
  • input (Tensor) - 形状为 \((..., irreps\_in.dim)\) 的张量。

输出:
  • output (Tensor) - 形状为 \((..., irreps\_out.dim)\) 的张量。

异常:
  • ValueError: 如果 irreps_scalarsirreps_gates 包含非标量的不可约表示。

  • ValueError: 如果 irreps_gates 的总乘积不匹配 irreps_gated 的总乘积。

支持平台:

Ascend

样例:

>>> from mindspore import ops
>>> from mindchemistry.e3.nn import Gate
>>> Gate('2x0e', [ops.tanh], '1x0o+2x0e', [ops.abs], '2x1o+1x2e')
Gate (2x0e+1x0o+2x0e+2x1o+1x2e -> 2x0e+2x1o+1x2e)