mindscience.e3nn.nn.OneHot

class mindscience.e3nn.nn.OneHot(num_types, dtype=mindspore.float32)

One-hot embedding with irreps support.

The output is annotated with Irreps to indicate that it transforms as num_types copies of the scalar (\(l = 0\)) representation, so every component is invariant under rotations. This lets the embedding feed directly into e3nn networks that expect irreps-annotated inputs.
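The scalar annotation explains why the feature width equals num_types. Each copy of a degree-\(l\) irrep spans \(2l + 1\) components, so \(l = 0\) contributes exactly one. The sketch below assumes irreps are given as plain (multiplicity, l) pairs. The helper irreps_dim is hypothetical and is not part of mindscience.e3nn.

```python
# Hypothetical helper, not part of mindscience.e3nn: total dimension of an
# irreps list given as (multiplicity, degree l) pairs.
def irreps_dim(irreps):
    # Each copy of the degree-l irrep spans 2l + 1 components.
    return sum(mul * (2 * l + 1) for mul, l in irreps)


num_types = 4
# OneHot's output is described as num_types scalar (l = 0) irreps.
scalar_irreps = [(num_types, 0)]
print(irreps_dim(scalar_irreps))  # 4: one component per atom type

# For contrast, a single vector (l = 1) irrep spans 3 components.
print(irreps_dim([(1, 1)]))  # 3
```

Scalars are the only irreps whose components do not mix under rotation. That is why a categorical label such as atom type is naturally encoded with \(l = 0\).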

Parameters
  • num_types (int) – Number of distinct atom types.

  • dtype (mindspore.dtype, optional) – Data type of the embedding. Default: mindspore.float32.

Inputs:
  • atom_type (Tensor) - Tensor of shape \((...)\), containing integer atom-type indices expected to lie in the range \([0, \text{num\_types})\).

Outputs:
  • output (Tensor) - One-hot tensor of shape \((..., \text{num\_types})\).
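The shape contract above can be illustrated with a framework-free reference. Arbitrary leading dimensions are preserved, and a trailing axis of length num_types is appended that holds a single 1.0 at the atom-type index. This is a sketch under stated assumptions and is not the mindscience implementation. The function name one_hot_reference is hypothetical, nested Python lists stand in for tensors, and the out-of-range check is a choice made here, since the module's behavior for invalid indices is not documented on this page.

```python
# Hypothetical pure-Python reference, not the mindscience implementation:
# maps nested integer indices of shape (...) to one-hot rows of shape
# (..., num_types).
def one_hot_reference(atom_type, num_types):
    if isinstance(atom_type, (list, tuple)):
        # Recurse over leading dimensions so any (...) shape is supported.
        return [one_hot_reference(a, num_types) for a in atom_type]
    # Assumption of this sketch: reject indices outside [0, num_types).
    if not 0 <= atom_type < num_types:
        raise ValueError(f"index {atom_type} outside [0, {num_types})")
    # Exactly one entry is 1.0; its position is the atom-type index.
    return [1.0 if i == atom_type else 0.0 for i in range(num_types)]


print(one_hot_reference([0, 2, 1], num_types=4))
# [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
print(one_hot_reference([[0, 3]], num_types=4))
# [[[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]]
```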

Examples

>>> from mindscience.e3nn.nn import OneHot
>>> from mindspore import Tensor
>>> one_hot = OneHot(num_types=4)
>>> atom_type = Tensor([0, 2, 1])
>>> out = one_hot(atom_type)
>>> print(out.shape)
(3, 4)