mindscience.e3nn.nn.OneHot
- class mindscience.e3nn.nn.OneHot(num_types, dtype=mindspore.float32)[源代码]
支持不可约表示注解的 one-hot 嵌入。 输出会使用
Irreps标注为标量( \(l = 0\) )表示,可在期望 irreps 注解的 e3nn 网络中直接使用。- 参数:
num_types (int) - 不同原子类型的数量。
dtype (mindspore.dtype,可选) - 嵌入的数据类型。默认值:
mindspore.float32。
- 输入:
atom_type (Tensor) - 形状为 \((...)\) 的张量,包含整数的原子类型索引。
- 输出:
output (Tensor) - 形状为 \((..., \text{num\_types})\) 的 one-hot 张量。
样例:
>>> from mindscience.e3nn.nn import OneHot >>> from mindspore import Tensor >>> one_hot = OneHot(num_types=4) >>> atom_type = Tensor([0, 2, 1]) >>> out = one_hot(atom_type) >>> print(out.shape) (3, 4)