mindchemistry.e3.nn.BatchNorm

class mindchemistry.e3.nn.BatchNorm(irreps, eps=1e-5, momentum=0.1, affine=True, reduce='mean', instance=False, normalization='component', dtype=float32)

Batch normalization for orthonormal representations. Each representation is normalized by its norm. The norm is invariant only under orthonormal representations; the irreducible representations given by the Wigner D-matrices (wigner_D) are orthonormal, so this normalization is equivariant for irreps inputs.
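To make the invariance argument concrete, here is a minimal NumPy sketch, not the library's implementation; `rotation_z` and `normalize_by_norm` are hypothetical helpers. A 3x3 rotation acting on l = 1 vectors is an orthonormal representation, so it preserves norms and a norm-based scaling commutes with it. A non-orthonormal map such as a diagonal stretch changes the norms, so the same scaling no longer commutes.

```python
import numpy as np

def rotation_z(theta):
    # 3x3 rotation about the z axis: an orthogonal (orthonormal) matrix.
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def normalize_by_norm(x, eps=1e-5):
    # Hypothetical helper: divide a batch of vectors (batch, 3) by the
    # square root of their mean squared norm over the batch.
    mean_sq_norm = np.mean(np.sum(x ** 2, axis=-1))
    return x / np.sqrt(mean_sq_norm + eps)

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3))    # a batch of l = 1 (vector) features
R = rotation_z(0.7)            # orthonormal representation of a rotation
S = np.diag([2.0, 1.0, 1.0])   # non-orthonormal linear map, for contrast

# Applying a matrix M to every row x_i (M @ x_i) is x @ M.T.
norms_kept = np.allclose(np.linalg.norm(x @ R.T, axis=-1), np.linalg.norm(x, axis=-1))
commutes_R = np.allclose(normalize_by_norm(x) @ R.T, normalize_by_norm(x @ R.T))
commutes_S = np.allclose(normalize_by_norm(x) @ S.T, normalize_by_norm(x @ S.T))
print(norms_kept, commutes_R, commutes_S)  # True True False
```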

Parameters
  • irreps (Union[str, Irrep, Irreps]) – the input irreps.

  • eps (float) – small value added to the variance to avoid division by zero during normalization. Default: 1e-5.

  • momentum (float) – momentum of the running average. Default: 0.1.

  • affine (bool) – whether to include learnable weight and bias parameters. Default: True.

  • reduce (str) – {'mean', 'max'}, method used to reduce. Default: 'mean'.

  • instance (bool) – whether to apply instance normalization instead of batch normalization. Default: False.

  • normalization (str) – {'component', 'norm'}, normalization method. Default: 'component'.

  • dtype (mindspore.dtype) – the data type of the input tensor. Default: mindspore.float32.
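The parameters above can be read off a simplified computation. The NumPy sketch below assumes the layer follows the e3nn-style norm-based batch normalization (mindchemistry's actual code may differ) and shows, for one irreps block in training mode, where `eps`, `reduce`, `normalization`, `instance` and the affine weight and bias enter. `normalize_block` and its arguments are hypothetical, and the running averages controlled by `momentum` are omitted.

```python
import numpy as np

def normalize_block(field, is_scalar, eps=1e-5, reduce="mean",
                    normalization="component", instance=False,
                    weight=None, bias=None):
    """Hypothetical sketch: normalize one irreps block in training mode.

    field has shape (batch, n, mul, d): n flattens the '...' axes,
    mul is the multiplicity and d = 2l + 1 is the irrep dimension.
    """
    if reduce not in ("mean", "max"):
        raise ValueError(f"reduce must be 'mean' or 'max', got {reduce!r}")
    if normalization not in ("component", "norm"):
        raise ValueError(f"normalization must be 'component' or 'norm', got {normalization!r}")
    if is_scalar:
        # Assumption (as in e3nn): only even scalars (0e) are mean-centered,
        # per sample for instance norm, over batch and '...' for batch norm.
        axes = (1,) if instance else (0, 1)
        field = field - field.mean(axis=axes, keepdims=True)
    # Squared norm per field: sum of components ('norm') or their mean ('component').
    sq = field ** 2
    sq = sq.sum(axis=3) if normalization == "norm" else sq.mean(axis=3)
    # Reduce over the flattened '...' axis -> (batch, mul).
    sq = sq.mean(axis=1) if reduce == "mean" else sq.max(axis=1)
    if not instance:
        sq = sq.mean(axis=0, keepdims=True)  # batch norm: also average over the batch
    scale = (sq + eps) ** -0.5               # eps avoids division by zero
    if weight is not None:
        scale = scale * weight               # affine weight, shape (mul,)
    out = field * scale[:, None, :, None]
    if bias is not None and is_scalar:
        out = out + bias[None, None, :, None]  # a constant bias is only equivariant for 0e
    return out

rng = np.random.default_rng(0)
vectors = rng.normal(size=(4, 5, 2, 3))           # batch=4, n=5, block 2x1o (d=3)
out = normalize_block(vectors, is_scalar=False)
print(np.mean(out ** 2))                          # close to 1 with 'component' normalization

scalars = rng.normal(loc=3.0, size=(4, 5, 3, 1))  # block 3x0e
out_s = normalize_block(scalars, is_scalar=True)
print(np.abs(out_s.mean(axis=(0, 1))).max())      # about 0: even scalars are centered
```

With `instance=True` the statistics are computed per sample instead of being averaged over the batch, which is why the final `mean(axis=0)` is skipped in that branch.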

Inputs:
  • input (Tensor) - a tensor of shape \((batch, ..., irreps.dim)\).

Outputs:
  • output (Tensor) - a tensor with the same shape as the input, \((batch, ..., irreps.dim)\).
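The last axis of the input must equal irreps.dim, the sum of mul * (2l + 1) over all terms. The hypothetical helper below is only a sketch for simple strings of the form 'MULxLP' joined by '+'; it is not the library's Irreps parser, which accepts more input forms. For the irreps in the example, '3x0o+2x0e+1x0o' has dimension 6, which matches the input shape (4, 6).

```python
def irreps_dim(irreps: str) -> int:
    """Hypothetical helper: total dimension of a simple irreps string.

    Each term 'MULxLP' contributes MUL * (2L + 1) components, where P is
    the parity letter 'e' or 'o'; a bare 'LP' term has multiplicity 1.
    """
    total = 0
    for term in irreps.replace(" ", "").split("+"):
        mul, _, ir = term.rpartition("x")
        l = int(ir[:-1])                  # drop the parity letter
        total += int(mul or 1) * (2 * l + 1)
    return total

print(irreps_dim("3x0o+2x0e+1x0o"))  # 6
print(irreps_dim("2x0e+1x1o+1x2e"))  # 2 + 3 + 5 = 10
```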

Raises
  • ValueError – If reduce is not in ['mean', 'max'].

  • ValueError – If normalization is not in ['component', 'norm'].

Supported Platforms:

Ascend

Examples

>>> from mindchemistry.e3.nn import BatchNorm
>>> from mindspore import ops
>>> bn = BatchNorm('3x0o+2x0e+1x0o')
>>> print(bn)
BatchNorm (3x0o+2x0e+1x0o, eps=1e-05, momentum=0.1)
>>> inputs = ops.ones((4, 6))
>>> outputs = bn(inputs)
>>> print(outputs.shape)
(4, 6)