mindchemistry.e3.o3.norm 源代码

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""norm"""
from mindspore import nn, ops, float32

from .irreps import Irreps
from .tensor_product import TensorProduct


class Norm(nn.Cell):
    r"""
    Compute the norm of every irrep contained in a direct sum of irreps.

    The norm is obtained from a tensor product of the input with itself, in
    which each irrep of the (simplified) input is contracted onto a scalar
    ``0e`` output of the same multiplicity.

    Args:
        irreps_in (Union[str, Irrep, Irreps]): Irreps for the input.
        squared (bool): If ``True``, return the squared norm instead of the
            norm. Default: ``False`` .
        dtype (mindspore.dtype): The type of input tensor.
            Default: ``mindspore.float32`` .
        ncon_dtype (mindspore.dtype): The type of input tensors of ncon
            computation module. Default: ``mindspore.float32`` .

    Inputs:
        - **v** (Tensor) - The shape of Tensor is :math:`(..., irreps\_in.dim)` .

    Outputs:
        - **output** (Tensor) - The shape of Tensor is :math:`(..., irreps\_out.dim)` .

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore as ms
        >>> import numpy as np
        >>> from mindchemistry.e3.o3 import Norm
        >>> n = Norm('3x1o')
        >>> v = ms.Tensor(np.linspace(1., 2., n.irreps_in.dim), dtype=ms.float32)
        >>> n(v).shape
        (1, 3)
    """

    def __init__(self, irreps_in, squared=False, dtype=float32, ncon_dtype=float32):
        super().__init__()
        self.squared = squared

        simplified_in = Irreps(irreps_in).simplify()
        # One scalar ("0e") output entry per input entry, same multiplicity.
        scalar_out = Irreps([(mul, "0e") for mul, _ in simplified_in])
        # Entry k of the input is paired with itself and written to entry k of
        # the output. The tuple layout is whatever TensorProduct expects --
        # presumably (in1, in2, out, mode, has_weight, path_weight); confirm
        # against TensorProduct.
        instructions = [
            (idx, idx, idx, "uuu", False, ir.dim)
            for idx, (_, ir) in enumerate(simplified_in)
        ]

        self.tp = TensorProduct(
            simplified_in,
            simplified_in,
            scalar_out,
            instructions,
            irrep_norm="component",
            dtype=dtype,
            ncon_dtype=ncon_dtype,
        )
        self.irreps_in = simplified_in
        self.irreps_out = scalar_out.simplify()

    def construct(self, v):
        """Return the (optionally squared) per-irrep norm of the input tensor."""
        squared_norm = self.tp(v, v)
        if not self.squared:
            # Clamp at zero first so that sqrt never receives a negative value.
            return ops.sqrt(ops.relu(squared_norm))
        return squared_norm

    def __repr__(self):
        """Return the class name followed by the simplified input irreps."""
        return f"{self.__class__.__name__} ({self.irreps_in})"