mindscience.e3nn.nn.BatchNorm

class mindscience.e3nn.nn.BatchNorm(irreps, eps=1e-5, momentum=0.1, affine=True, reduce='mean', instance=False, normalization='component', dtype=mindspore.float32)

Batch normalization tailored for orthonormal group representations.

Unlike conventional BatchNorm, this layer normalizes each irreducible representation block by its invariant norm, ensuring equivariance is preserved under group actions such as rotations. Statistics are computed independently per multiplicity block, keeping the tensor structure intact.

The norm is invariant only for orthonormal representations. The irreducible representations given by wigner_D (and any real basis derived from them) satisfy this requirement, so the layer is safe to use with standard e3nn irreps.
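The invariance argument above can be checked numerically. The following is a minimal pure-Python sketch, not the library's implementation: it assumes a single batch of l=1 vectors, and the helpers `rotate_z` and `normalize_by_mean_sq_norm` are hypothetical names introduced only for illustration. It rescales each vector by the batch mean of the squared norm and confirms that rotating before or after normalization gives the same result.

```python
import math

def rotate_z(v, angle):
    """Rotate a 3-vector about the z-axis (an orthogonal transformation)."""
    c, s = math.cos(angle), math.sin(angle)
    x, y, z = v
    return (c * x - s * y, s * x + c * y, z)

def normalize_by_mean_sq_norm(batch, eps=1e-5):
    """Scale every vector by 1 / sqrt(batch mean of |v|^2 + eps).

    Illustrative only: the scale depends on the vectors solely through
    their norms, which a rotation leaves unchanged.
    """
    mean_sq = sum(x * x + y * y + z * z for x, y, z in batch) / len(batch)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [tuple(scale * c for c in v) for v in batch]

batch = [(1.0, 2.0, 3.0), (-0.5, 0.0, 4.0), (2.0, -1.0, 0.5)]
angle = 0.7

# Path 1: rotate the inputs, then normalize.
rotate_then_norm = normalize_by_mean_sq_norm([rotate_z(v, angle) for v in batch])
# Path 2: normalize the inputs, then rotate.
norm_then_rotate = [rotate_z(v, angle) for v in normalize_by_mean_sq_norm(batch)]

# Largest component-wise difference between the two paths.
max_diff = max(
    abs(a - b)
    for u, w in zip(rotate_then_norm, norm_then_rotate)
    for a, b in zip(u, w)
)
print(max_diff < 1e-9)
```

Because the scale factor is built from norms only, the two paths agree up to floating-point rounding, which is the equivariance property the layer relies on.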

Parameters
  • irreps (Union[str, Irrep, Irreps]) – Input irreps.

  • eps (float, optional) – Small constant to avoid division by zero when normalizing by variance. Default: 1e-5.

  • momentum (float, optional) – Momentum for the running average. Default: 0.1.

  • affine (bool, optional) – Whether to include learnable weight and bias parameters. Default: True.

  • reduce (str, optional) – Reduction method, either 'mean' or 'max'. Default: 'mean'.

  • instance (bool, optional) – If True, apply instance normalization instead of batch normalization. Default: False.

  • normalization (str, optional) – Normalization method, either 'component' or 'norm'. Default: 'component'.

  • dtype (mindspore.dtype, optional) – Data type of the input tensor. Default: mindspore.float32.
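To make the `normalization`, `reduce`, and `momentum` options more concrete, here is a hedged pure-Python sketch. It is modelled on how the upstream PyTorch e3nn layer is commonly described and is an assumption about this port, not a statement of its internals; `field_norm`, `reduce_values`, and `update_running` are hypothetical helper names.

```python
def field_norm(vector, normalization="component"):
    """Squared size of one irrep copy.

    'norm'      -> sum of squared components
    'component' -> mean of squared components
    """
    sq = sum(c * c for c in vector)
    if normalization == "norm":
        return sq
    if normalization == "component":
        return sq / len(vector)
    raise ValueError(f"Invalid normalization option {normalization!r}")

def reduce_values(values, reduce="mean"):
    """Combine per-sample values with the chosen reduction."""
    if reduce == "mean":
        return sum(values) / len(values)
    if reduce == "max":
        return max(values)
    raise ValueError(f"Invalid reduce option {reduce!r}")

def update_running(running, batch_value, momentum=0.1):
    """Exponential moving average (assumed form of the running statistic)."""
    return (1 - momentum) * running + momentum * batch_value

# Two samples of a single l=1 irrep copy.
samples = [[3.0, 0.0, 0.0], [0.0, 1.0, 0.0]]

norms = [field_norm(v, "norm") for v in samples]        # sums of squares
comps = [field_norm(v, "component") for v in samples]   # means of squares

stat_mean = reduce_values(norms, "mean")
stat_max = reduce_values(norms, "max")
running = update_running(1.0, stat_mean, momentum=0.1)
```

With `'component'` the statistic does not grow with the dimension of the irrep, whereas `'norm'` measures the full vector length; which one is appropriate depends on how the rest of the network is normalized.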

Inputs:
  • input (Tensor) - Tensor of shape \((batch, ..., irreps.dim)\).

Outputs:
  • output (Tensor) - Tensor of shape \((batch, ..., irreps.dim)\), the same shape as the input.

Raises
  • ValueError – If reduce is not in ['mean', 'max'].

  • ValueError – If normalization is not in ['component', 'norm'].

Examples

>>> from mindscience.e3nn.nn import BatchNorm
>>> from mindspore import ops, Tensor
>>> bn = BatchNorm('3x0o+2x0e+1x0o')
>>> print(bn)
BatchNorm (3x0o+2x0e+1x0o, eps=1e-05, momentum=0.1)
>>> inputs = Tensor(ops.ones((4, 6)))
>>> outputs = bn(inputs)
>>> print(outputs.shape)
(4, 6)
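The last axis in the example above has size 6 because it must equal `irreps.dim`. As a rough illustration, the sketch below computes that dimension from an irreps string, assuming the usual e3nn convention that a term `MxLp` contributes `M * (2L + 1)` components. The function `irreps_dim` is a hypothetical helper written for this example and is not part of the library.

```python
def irreps_dim(irreps):
    """Total dimension of an irreps string such as '3x0o+2x0e+1x0o'.

    Each term is '<mul>x<l><parity>' (the '<mul>x' prefix is optional)
    and contributes mul * (2 * l + 1) components.
    """
    total = 0
    for term in irreps.split("+"):
        term = term.strip()
        if "x" in term:
            mul_str, ir = term.split("x")
            mul = int(mul_str)
        else:
            mul, ir = 1, term
        if ir[-1] not in ("e", "o"):
            raise ValueError(f"Unrecognized parity in term {term!r}")
        degree = int(ir[:-1])  # the rotation order l
        total += mul * (2 * degree + 1)
    return total

scalar_dim = irreps_dim("3x0o+2x0e+1x0o")  # six scalar channels
mixed_dim = irreps_dim("2x1o+0e")          # two vectors plus one scalar
print(scalar_dim, mixed_dim)
```

For irreps that include vectors or higher orders, the input passed to the layer needs a correspondingly larger last axis.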