mindchemistry.e3.nn.BatchNorm
- class mindchemistry.e3.nn.BatchNorm(irreps, eps=1e-5, momentum=0.1, affine=True, reduce='mean', instance=False, normalization='component', dtype=float32)
Batch normalization for orthonormal representations. Each representation is normalized by its norm. Note that the norm is invariant only for orthonormal representations; the irreducible representations given by wigner_D are orthonormal.
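The reason the norm is used can be checked numerically. The small NumPy example below is an illustration, not library code; `rotation_z` and `normalize_by_norm` are hypothetical helpers. A rotation matrix is orthogonal, so it preserves a vector's norm and dividing by the norm commutes with the rotation (the operation stays equivariant). A non-orthogonal matrix, here a shear, changes the norm, which is why the invariance only holds for orthonormal representations.

```python
import numpy as np

def rotation_z(theta: float) -> np.ndarray:
    """3x3 rotation about the z axis: an orthogonal (orthonormal) matrix."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def normalize_by_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Divide a vector by its Euclidean norm; eps guards against division by zero."""
    return x / (np.linalg.norm(x) + eps)

v = np.array([1.0, 2.0, 2.0])          # one l=1 feature, norm 3
R = rotation_z(0.7)                     # orthogonal transform
S = np.array([[1.0, 1.0, 0.0],          # shear: invertible but NOT orthogonal
              [0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0]])

norm_kept = np.isclose(np.linalg.norm(v), np.linalg.norm(R @ v))
commutes = np.allclose(normalize_by_norm(R @ v), R @ normalize_by_norm(v))
shear_kept = np.isclose(np.linalg.norm(v), np.linalg.norm(S @ v))
print(norm_kept, commutes, shear_kept)  # True True False
```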
- Parameters
irreps – the irreducible representations of the input, for example '3x0o+2x0e+1x0o'.
eps (float) – small value used to avoid division by zero when normalizing by the variance. Default: 1e-5.
momentum (float) – momentum of the running average. Default: 0.1.
affine (bool) – whether to use learnable weight and bias parameters. Default: True.
reduce (str) – {'mean', 'max'}, method used for the reduction. Default: 'mean'.
instance (bool) – whether to apply instance normalization instead of batch normalization. Default: False.
normalization (str) – {'component', 'norm'}, normalization method. Default: 'component'.
dtype (mindspore.dtype) – the dtype of the input tensor. Default: mindspore.float32.
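The following NumPy sketch shows one way these parameters can interact. It assumes an e3nn-style algorithm and is not MindChemistry's implementation: squared norms are computed per irrep copy ('component' averages over the 2l + 1 components, 'norm' sums them), reduced over an extra sample axis with 'mean' or 'max', averaged over the batch unless `instance` is set, and tracked in a running average updated as (1 - momentum) * old + momentum * new (an assumed convention). The class `IrrepBlockNormSketch` is hypothetical. It handles a single irrep block of shape (batch, samples, mul, 2l + 1) and omits the mean subtraction and bias that e3nn-style implementations typically apply to even scalar (0e) channels.

```python
import numpy as np

class IrrepBlockNormSketch:
    """Hypothetical, simplified norm-based batch norm for ONE irrep block.

    Input x has shape (batch, samples, mul, dim) with dim = 2l + 1.
    Mean subtraction and bias for scalar channels are deliberately omitted.
    """

    def __init__(self, mul, eps=1e-5, momentum=0.1, affine=True,
                 reduce="mean", instance=False, normalization="component"):
        # Mirror the documented ValueError conditions.
        if reduce not in ("mean", "max"):
            raise ValueError(f"reduce must be 'mean' or 'max', got {reduce!r}")
        if normalization not in ("component", "norm"):
            raise ValueError(
                f"normalization must be 'component' or 'norm', got {normalization!r}")
        self.eps, self.momentum = eps, momentum
        self.reduce, self.instance = reduce, instance
        self.normalization = normalization
        self.weight = np.ones(mul) if affine else None  # per-channel scale
        self.running_var = np.ones(mul)                  # running average of squared norms
        self.training = True

    def __call__(self, x):
        # Squared norm of each irrep copy: (batch, samples, mul).
        if self.normalization == "norm":
            field_norm = np.sum(x ** 2, axis=-1)
        else:  # "component": average over the 2l + 1 components
            field_norm = np.mean(x ** 2, axis=-1)

        # Reduce over the sample axis: (batch, mul).
        if self.reduce == "mean":
            field_norm = field_norm.mean(axis=1)
        else:
            field_norm = field_norm.max(axis=1)

        if self.instance:
            var = field_norm                      # per example: (batch, mul)
        elif self.training:
            var = field_norm.mean(axis=0)         # over the batch: (mul,)
            self.running_var = ((1 - self.momentum) * self.running_var
                                + self.momentum * var)
        else:
            var = self.running_var                # inference: (mul,)

        scale = (var + self.eps) ** -0.5
        if self.weight is not None:
            scale = scale * self.weight
        # Broadcast the scale over the sample and component axes.
        if self.instance:
            scale = scale[:, None, :, None]
        else:
            scale = scale[None, None, :, None]
        return x * scale
```

Setting `training = False` on the sketch switches it to the running average, following the usual train/inference split of batch normalization.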
- Inputs:
input (Tensor) - The shape of Tensor is (batch, ..., irreps.dim).
- Outputs:
output (Tensor) - The shape of Tensor is (batch, ..., irreps.dim).
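The last axis, irreps.dim, is the sum over terms of multiplicity × (2l + 1), since an irrep of degree l has 2l + 1 components. The helper `irreps_dim` below is hypothetical and assumes the 'MULxLP' string format used in the example (a bare 'LP' meaning multiplicity 1). For '3x0o+2x0e+1x0o' every term has l = 0, giving 3 + 2 + 1 = 6, which matches the input shape (4, 6) in the example.

```python
def irreps_dim(irreps: str) -> int:
    """Hypothetical helper: total dimension of an irreps string like '3x0o+2x0e+1x0o'.

    Each term is 'MULxLP' (or just 'LP' for multiplicity 1), where L is the degree
    and P the parity ('e' or 'o'); a degree-L irrep has 2L + 1 components.
    """
    total = 0
    for term in irreps.replace(" ", "").split("+"):
        mul_str, _, ir = term.rpartition("x")
        mul = int(mul_str) if mul_str else 1
        if ir[-1] not in "eo":
            raise ValueError(f"bad parity in term {term!r}")
        l = int(ir[:-1])  # degree, with the parity letter stripped
        total += mul * (2 * l + 1)
    return total

print(irreps_dim("3x0o+2x0e+1x0o"))  # 6
print(irreps_dim("2x0e+1x1o"))       # 2*1 + 1*3 = 5
```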
- Raises
ValueError – If reduce is not in ['mean', 'max'].
ValueError – If normalization is not in ['component', 'norm'].
- Supported Platforms:
Ascend
Examples
>>> from mindchemistry.e3.nn import BatchNorm
>>> from mindspore import ops, Tensor
>>> bn = BatchNorm('3x0o+2x0e+1x0o')
>>> print(bn)
BatchNorm (3x0o+2x0e+1x0o, eps=1e-05, momentum=0.1)
>>> inputs = Tensor(ops.ones((4, 6)))
>>> outputs = bn(inputs)
>>> print(outputs.shape)
(4, 6)