# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""normact"""
from mindspore import nn, Parameter, float32, ops
from mindspore.common.initializer import initializer
from ..o3.irreps import Irreps
from ..o3.tensor_product import TensorProduct
from ..o3.norm import Norm
class NormActivation(nn.Cell):
    r"""
    Activation function for the norm of irreps.

    Applies a scalar activation to the norm of each irrep and outputs a (normalized) version of that irrep multiplied
    by the scalar output of the scalar activation. Optionally, a learnable bias can be added to the norms before
    the activation, and the resulting features can be normalized by their original norm to preserve angular
    information while only modulating their magnitude.

    Args:
        irreps_in (Union[str, Irrep, Irreps]): Input irreps.
        act (Callable): Activation function applied to the norm of each irrep.
        normalize (bool, optional): Whether to normalize input features before multiplying by the scalars from the
            nonlinearity. Default: ``True``.
        epsilon (float, optional): When ``normalize``, norms smaller than ``epsilon`` are clamped to ``epsilon``
            to prevent division by zero. Must be ``None`` when ``normalize`` is ``False``. Default: ``None``,
            which is replaced by ``1e-8`` when ``normalize`` is ``True``.
        bias (bool, optional): Whether to apply a learnable additive bias to the inputs of ``act``. Default: ``False``.
        init_method (Union[str, float, mindspore.common.initializer], optional): Initialization method of the bias
            parameter; only used when ``bias`` is ``True``. Default: ``'zeros'``.
        dtype (mindspore.dtype, optional): Data type of input tensors. Default: ``mindspore.float32``.
        ncon_dtype (mindspore.dtype, optional): Data type for ncon computation. Default: ``mindspore.float32``.

    Inputs:
        - **input** (Tensor) - The shape of Tensor is :math:`(..., irreps\_in.dim)`.

    Outputs:
        - **output** (Tensor) - The shape of Tensor is :math:`(..., irreps\_in.dim)`.

    Raises:
        ValueError: If `epsilon` is not None and `normalize` is False.
        ValueError: If `epsilon` is not positive.

    Examples:
        >>> from mindscience.e3nn.nn import NormActivation
        >>> from mindspore import ops, Tensor
        >>> norm_activation = NormActivation("2x1e", ops.sigmoid, bias=True)
        >>> print(norm_activation)
        NormActivation [sigmoid] (2x1e -> 2x1e)
        >>> inputs = Tensor(ops.ones((4, 6)))
        >>> outputs = norm_activation(inputs)
        >>> print(outputs.shape)
        (4, 6)
    """

    def __init__(self,
                 irreps_in,
                 act,
                 normalize=True,
                 epsilon=None,
                 bias=False,
                 init_method='zeros',
                 dtype=float32,
                 ncon_dtype=float32):
        super().__init__()
        self.irreps_in = Irreps(irreps_in)
        # Only the magnitudes are modulated, so the output irreps equal the input irreps.
        self.irreps_out = Irreps(irreps_in)

        # Resolve / validate epsilon: it is required (defaulting to 1e-8) when normalizing and
        # forbidden otherwise.
        if epsilon is None and normalize:
            epsilon = 1e-8
        elif epsilon is not None and not normalize:
            raise ValueError("`epsilon` and `normalize = False` don't make sense together.")
        elif epsilon is not None and not epsilon > 0:
            raise ValueError(f"epsilon {epsilon} is invalid, must be strictly positive.")
        self.epsilon = epsilon

        # When an epsilon is in use, the Norm cell returns *squared* norms so they can be clamped at
        # epsilon**2 and then square-rooted (equivalent to clamping the norm at epsilon). A single
        # flag drives both the Norm configuration and the sqrt in `construct`, so the two can never
        # disagree (e.g. if epsilon**2 underflows to 0.0 for a tiny epsilon).
        self._squared_norms = epsilon is not None
        self._eps_squared = epsilon * epsilon if self._squared_norms else 0.0
        self.norm = Norm(irreps_in, squared=self._squared_norms, dtype=dtype)

        self.act = act
        self.normalize = normalize

        if bias:
            # One learnable offset per irrep, added to the norms before the activation.
            self.bias = Parameter(initializer(init_method, (self.irreps_in.num_irreps,), dtype),
                                  name=self.__class__.__name__)
        else:
            self.bias = None

        # Element-wise product between the scalars derived from the norms (irreps
        # `self.norm.irreps_out`, presumably one scalar per input irrep) and the input irreps.
        self.scalar_multiplier = TensorProduct(irreps_in1=self.norm.irreps_out,
                                               irreps_in2=irreps_in,
                                               instructions='element',
                                               dtype=dtype,
                                               ncon_dtype=ncon_dtype)

    def construct(self, v):
        """Implement the norm-activation function for the input tensor.

        Args:
            v (Tensor): Input features, shape ``(..., irreps_in.dim)``.

        Returns:
            Tensor with the same shape as ``v``: each irrep of ``v`` multiplied by
            ``act(norm + bias)``, additionally divided by the (clamped) norm when ``normalize`` is True.
        """
        norms = self.norm(v)
        if self._squared_norms:
            # Clamp squared norms from below at epsilon**2, then recover the clamped norms.
            norms[norms < self._eps_squared] = self._eps_squared
            norms = ops.sqrt(norms)

        nonlin_arg = norms
        if self.bias is not None:
            nonlin_arg = nonlin_arg + self.bias

        scalings = self.act(nonlin_arg)
        if self.normalize:
            # normalize=True always implies an epsilon, so the norms were clamped above,
            # guarding this division against zero-norm irreps.
            scalings = scalings / norms

        return self.scalar_multiplier(scalings, v)

    def __repr__(self):
        # Not every callable exposes ``__name__`` (e.g. Cell instances or functools.partial
        # objects); fall back to the callable's type name instead of raising AttributeError.
        act_name = getattr(self.act, "__name__", type(self.act).__name__)
        return f"{self.__class__.__name__} [{act_name}] ({self.irreps_in} -> {self.irreps_out})"