mindscience.e3nn.nn.OneHot

class mindscience.e3nn.nn.OneHot(num_types, dtype=mindspore.float32)[源代码]

支持不可约表示注解的 one-hot 嵌入。 输出会使用 Irreps 标注为标量( \(l = 0\) )表示,可在期望 irreps 注解的 e3nn 网络中直接使用。

参数:
  • num_types (int) - 不同原子类型的数量。

  • dtype (mindspore.dtype,可选) - 嵌入的数据类型。默认值: mindspore.float32

输入:
  • atom_type (Tensor) - 形状为 \((...)\) 的张量,包含整数的原子类型索引。

输出:
  • output (Tensor) - 形状为 \((..., \text{num\_types})\) 的 one-hot 张量。

样例:

>>> from mindscience.e3nn.nn import OneHot
>>> from mindspore import Tensor
>>> one_hot = OneHot(num_types=4)
>>> atom_type = Tensor([0, 2, 1])
>>> out = one_hot(atom_type)
>>> print(out.shape)
(3, 4)