# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""batchnorm"""
from mindspore import nn, Parameter, ops, float32
from ..o3.irreps import Irreps


class BatchNorm(nn.Cell):
r"""
Batch normalization for orthonormal representations.
It normalizes by the norm of the representations.
Note that the norm is invariant only for orthonormal representations.
Irreducible representations `wigner_D` are orthonormal.
Args:
irreps (Union[str, Irrep, Irreps]): the input irreps.
eps (float): avoid division by zero when we normalize by the variance. Default: ``1e-5``.
momentum (float): momentum of the running average. Default: ``0.1``.
affine (bool): do we have weight and bias parameters. Default: ``True``.
reduce (str): {'mean', 'max'}, method used to reduce. Default: ``'mean'``.
instance (bool): apply instance norm instead of batch norm. Default: ``Flase``.
normalization (str): {'component', 'norm'}, normalization method. Default: ``'component'``.
dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32``.
Inputs:
- **input** (Tensor) - The shape of Tensor is :math:`(batch, ..., irreps.dim)`.
Outputs:
- **output** (Tensor) - The shape of Tensor is :math:`(batch, ..., irreps.dim)`.
Raises:
ValueError: If `reduce` is not in ['mean', 'max'].
ValueError: If `normalization` is not in ['component', 'norm'].
Supported Platforms:
``Ascend``
Examples:
>>> from mindchemistry.e3.nn import BatchNorm
>>> from mindspore import ops, Tensor
>>> bn = BatchNorm('3x0o+2x0e+1x0o')
>>> print(bn)
BatchNorm (3x0o+2x0e+1x0o, eps=1e-05, momentum=0.1)
>>> inputs = Tensor(ops.ones((4, 6)))
>>> outputs = bn(inputs)
>>> print(outputs.shape)
(4, 6)
"""

    def __init__(self, irreps, eps=1e-5, momentum=0.1, affine=True, reduce='mean', instance=False,
                 normalization='component', dtype=float32):
        super().__init__()
        self.irreps = Irreps(irreps)
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.instance = instance
        self.reduce = reduce
        self.normalization = normalization
        self.training = True
        # Running statistics: the mean is only tracked for scalar irreps
        # (shifting non-scalar irreps would break equivariance), while the
        # variance is tracked for every irrep.
        num_scalar = sum(mul for mul, ir in self.irreps if ir.is_scalar())
        num_features = self.irreps.num_irreps
        self.running_mean = None if self.instance else Parameter(ops.zeros(num_scalar, dtype=dtype),
                                                                 requires_grad=False)
        self.running_var = None if self.instance else Parameter(ops.ones(num_features, dtype=dtype),
                                                                requires_grad=False)
        self.weight = Parameter(ops.ones(num_features, dtype=dtype)) if affine else None
        self.bias = Parameter(ops.zeros(num_scalar, dtype=dtype)) if affine else None

    def _roll_avg(self, curr, update):
        """Exponential moving average of a running statistic."""
        return (1 - self.momentum) * curr + self.momentum * update
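    # Worked example (illustrative numbers, not from the original source):
    # with the default momentum=0.1, curr=1.0 and update=2.0 give
    # 0.9 * 1.0 + 0.1 * 2.0 = 1.1, i.e. the running statistic drifts slowly
    # toward each new batch estimate.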

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.irreps}, eps={self.eps}, momentum={self.momentum})"

    def construct(self, inputs):
        """Apply batch (or instance) normalization to `inputs`, irrep by irrep."""
        inputs_shape = inputs.shape
        batch = inputs_shape[0]
        dim = inputs_shape[-1]
        # Flatten all intermediate dims into a single 'sample' dim: [batch, sample, dim].
        inputs = inputs.reshape(batch, -1, dim)
        new_means = []
        new_vars = []
        fields = []
        # Running offsets: ix walks the flattened feature dim, irm/irv walk the
        # running mean/var buffers, iw/ib walk the weight/bias parameters.
        ix = 0
        irm = 0
        irv = 0
        iw = 0
        ib = 0
        for mir in self.irreps.data:
            mul = mir.mul
            ir = mir.ir
            d = ir.dim
            field = inputs[:, :, ix: ix + mul * d]  # [batch, sample, mul * repr]
            ix += mul * d
            # [batch, sample, mul, repr]
            field = field.reshape(batch, -1, mul, d)
            if ir.is_scalar():  # scalars: subtract the (running) mean
                if self.training or self.instance:
                    if self.instance:
                        field_mean = field.mean(1).reshape(batch, mul)  # [batch, mul]
                    else:
                        field_mean = field.mean([0, 1]).reshape(mul)  # [mul]
                        new_means.append(
                            self._roll_avg(self.running_mean[irm:irm + mul], field_mean)
                        )
                else:
                    field_mean = self.running_mean[irm: irm + mul]
                irm += mul
                # [batch, sample, mul, repr]
                field = field - field_mean.reshape(-1, 1, mul, 1)
            if self.training or self.instance:
                if self.normalization == 'norm':
                    field_norm = field.pow(2).sum(3)  # [batch, sample, mul]
                elif self.normalization == 'component':
                    field_norm = field.pow(2).mean(3)  # [batch, sample, mul]
                else:
                    raise ValueError(f"Invalid normalization option {self.normalization}")

                if self.reduce == 'mean':
                    field_norm = field_norm.mean(1)  # [batch, mul]
                elif self.reduce == 'max':
                    field_norm = ops.amax(field_norm, 1)  # [batch, mul]
                else:
                    raise ValueError(f"Invalid reduce option {self.reduce}")

                if not self.instance:
                    field_norm = field_norm.mean(0)  # [mul]
                    new_vars.append(self._roll_avg(self.running_var[irv: irv + mul], field_norm))
            else:
                field_norm = self.running_var[irv: irv + mul]
            irv += mul

            field_norm = (field_norm + self.eps).pow(-0.5)  # [(batch,) mul]
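            # Added note: at this point field_norm holds 1 / sqrt(var + eps)
            # (shaped [batch, mul] in instance mode, [mul] otherwise), so the
            # multiplication below rescales each irrep channel by its inverse
            # standard deviation.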
            if self.affine:
                weight = self.weight[iw: iw + mul]  # [mul]
                iw += mul
                field_norm = field_norm * weight  # [(batch,) mul]

            field = field * field_norm.reshape(-1, 1, mul, 1)  # [batch, sample, mul, repr]

            if self.affine and ir.is_scalar():  # the bias is only equivariant on scalars
                bias = self.bias[ib: ib + mul]  # [mul]
                ib += mul
                field += bias.reshape(mul, 1)  # [batch, sample, mul, repr]

            fields.append(field.reshape(batch, -1, mul * d))  # [batch, sample, mul * repr]
        if self.training and not self.instance:
            # Only update the running buffers when something was accumulated;
            # `new_means` is empty when the irreps contain no scalar channels,
            # and `ops.cat` of an empty list would fail.
            if new_means:
                ops.assign(self.running_mean, ops.cat(new_means))
            if new_vars:
                ops.assign(self.running_var, ops.cat(new_vars))

        output = ops.cat(fields, 2)
        return output.reshape(inputs_shape)
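

# A minimal usage sketch (an added illustration, not part of the original
# module). Because of the relative import above it is meant to be read, or run
# via `python -m` from inside the package, rather than executed directly; it
# demonstrates the shape contract documented in the class docstring.
if __name__ == "__main__":
    from mindspore import Tensor

    bn = BatchNorm('3x0o+2x0e+1x0o')  # irreps.dim == 6
    x = Tensor(ops.ones((4, 6)))
    print(bn(x).shape)  # (4, 6): the last dim always equals irreps.dim

    # Instance normalization keeps per-sample statistics and no running buffers.
    bn_inst = BatchNorm('3x0o+2x0e+1x0o', instance=True)
    print(bn_inst(x).shape)  # (4, 6)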