# Source code for mindscience.e3nn.nn.normact

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""normact"""
from mindspore import nn, Parameter, float32, ops
from mindspore.common.initializer import initializer

from ..o3.irreps import Irreps
from ..o3.tensor_product import TensorProduct
from ..o3.norm import Norm


class NormActivation(nn.Cell):
    r"""
    Activation function for the norm of irreps.

    Applies a scalar activation to the norm of each irrep and outputs a (normalized) version of that irrep
    multiplied by the scalar output of the scalar activation. Optionally, a learnable bias can be added to the
    norms before the activation, and the resulting features can be normalized by their original norm to
    preserve angular information while only modulating their magnitude.

    Args:
        irreps_in (Union[str, Irrep, Irreps]): Input irreps.
        act (Func): Activation function applied to the norm of each irrep.
        normalize (bool, optional): Whether to normalize input features before multiplying by the scalars
            from the nonlinearity. Default: ``True``.
        epsilon (float, optional): When ``normalize``, norms smaller than ``epsilon`` are clamped to
            ``epsilon`` to prevent division by zero. Must be ``None`` if ``normalize`` is ``False``.
            Default: ``None``, which is replaced by ``1e-8`` when ``normalize`` is ``True``.
        bias (bool, optional): Whether to apply a learnable additive bias to the inputs of ``act``.
            Default: ``False``.
        init_method (Union[str, float, mindspore.common.initializer], optional): Parameter initialization
            method. Default: ``'zeros'``.
        dtype (mindspore.dtype, optional): Data type of input tensors. Default: ``mindspore.float32``.
        ncon_dtype (mindspore.dtype, optional): Data type for ncon computation.
            Default: ``mindspore.float32``.

    Inputs:
        - **input** (Tensor) - The shape of Tensor is :math:`(..., irreps\_in.dim)`.

    Outputs:
        - **output** (Tensor) - The shape of Tensor is :math:`(..., irreps\_in.dim)`.

    Raises:
        ValueError: If `epsilon` is not None and `normalize` is False.
        ValueError: If `epsilon` is not positive.

    Examples:
        >>> from mindscience.e3nn.nn import NormActivation
        >>> from mindspore import ops, Tensor
        >>> norm_activation = NormActivation("2x1e", ops.sigmoid, bias=True)
        >>> print(norm_activation)
        NormActivation [sigmoid] (2x1e -> 2x1e)
        >>> inputs = Tensor(ops.ones((4, 6)))
        >>> outputs = norm_activation(inputs)
        >>> print(outputs.shape)
        (4, 6)
    """

    def __init__(self, irreps_in, act, normalize=True, epsilon=None, bias=False, init_method='zeros',
                 dtype=float32, ncon_dtype=float32):
        super().__init__()
        self.irreps_in = Irreps(irreps_in)
        # The activation only rescales each irrep, so the output irreps equal the input irreps.
        self.irreps_out = Irreps(irreps_in)

        if epsilon is None and normalize:
            epsilon = 1e-8
        elif epsilon is not None and not normalize:
            raise ValueError("`epsilon` and `normalize = False` don't make sense together.")
        elif epsilon is not None and not epsilon > 0:
            raise ValueError(f"epsilon {epsilon} is invalid, must be strictly positive.")
        self.epsilon = epsilon
        # The clamp in ``construct`` is applied to *squared* norms, so keep epsilon**2 around.
        # 0.0 acts as the "no clamp" sentinel (only reachable when ``normalize`` is False).
        self._eps_squared = epsilon * epsilon if epsilon is not None else 0.0

        # When a clamp will be applied, ask Norm for squared norms; ``construct`` then clamps and
        # takes the square root itself. Otherwise Norm returns the plain norms directly.
        self.norm = Norm(irreps_in, squared=(epsilon is not None), dtype=dtype)
        self.act = act
        self.normalize = normalize
        if bias:
            # One learnable bias per irrep, added to the norm before ``act``.
            self.bias = Parameter(initializer(init_method, (self.irreps_in.num_irreps,), dtype),
                                  name=self.__class__.__name__)
        else:
            self.bias = None

        # Combines the per-irrep scalars (Norm's output irreps) with the input irreps using
        # 'element' instructions, i.e. rescales each irrep by its scalar.
        self.scalar_multiplier = TensorProduct(irreps_in1=self.norm.irreps_out, irreps_in2=irreps_in,
                                               instructions='element', dtype=dtype, ncon_dtype=ncon_dtype)

    def construct(self, v):
        """Implement the norm-activation function for the input tensor.

        Args:
            v (Tensor): Input features of shape ``(..., irreps_in.dim)``.

        Returns:
            Tensor: Features of the same shape, each irrep rescaled by ``act`` of its (biased) norm,
            divided by the norm when ``normalize`` is True.
        """
        norms = self.norm(v)
        if self._eps_squared > 0:
            # ``norms`` holds squared norms here. Clamp them from below with a functional op instead of
            # an in-place masked write on Norm's output, then recover the actual (>= epsilon) norms.
            norms = ops.sqrt(ops.maximum(norms, self._eps_squared))

        nonlin_arg = norms
        if self.bias is not None:
            nonlin_arg = nonlin_arg + self.bias

        scalings = self.act(nonlin_arg)
        if self.normalize:
            # ``normalize`` implies epsilon was set, so norms were clamped above and the division is safe.
            scalings = scalings / norms

        return self.scalar_multiplier(scalings, v)

    def __repr__(self):
        # ``act`` may be a callable without ``__name__`` (e.g. an ``nn.Cell`` instance or a
        # ``functools.partial``); fall back to its type name rather than raising AttributeError.
        act_name = getattr(self.act, "__name__", type(self.act).__name__)
        return f"{self.__class__.__name__} [{act_name}] ({self.irreps_in} -> {self.irreps_out})"