# Source code for mindscience.e3nn.so2_conv.so3

# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
SO(3) rotation utilities for the SO(2) convolution package.

Provides ``SO3Rotation``, which builds Wigner-D matrices from batches of
3x3 rotation matrices and uses them to rotate irreps-structured features
(``rotate``) and to rotate them back (``rotate_inv``).
"""
import mindspore as ms
from mindspore import ops, vmap, jit_class
from mindspore.numpy import tensordot
from .. import o3
from ..o3 import Irreps

from .wigner import wigner_D


@jit_class
class SO3Rotation:
    """
    Class for handling SO(3) rotations of spherical-harmonic irreps.

    :meth:`set_wigner` builds Wigner-D matrices (and their transposes) for
    l = 0 ... ``lmax`` from a batch of rotation matrices. :meth:`rotate` applies
    them to features laid out as ``irreps_in``. :meth:`rotate_inv` applies the
    transposed matrices to per-irrep features laid out as ``irreps_out`` and
    concatenates the result.

    The Wigner-D list is indexed by ``l``, so ``lmax`` should be at least the
    largest ``l`` appearing in ``irreps_in`` and ``irreps_out``.

    Args:
        lmax (int): Maximum angular momentum to be considered.
        irreps_in (Union[str, Irreps]): Input irreps specification.
        irreps_out (Union[str, Irreps]): Output irreps specification.

    Examples:
        >>> from mindscience.e3nn.so2_conv import SO3Rotation
        >>> rot = SO3Rotation(lmax=2, irreps_in="1x0e + 1x1o", irreps_out="1x1o")
        >>> # rot_mat3x3: (batch, 3, 3) rotation matrices;
        >>> # embedding: (batch, 4) features laid out as irreps_in.
        >>> wigner, wigner_inv = rot.set_wigner(rot_mat3x3)
        >>> rotated = rot.rotate(embedding, wigner)
    """

    def __init__(self, lmax, irreps_in, irreps_out):
        self.lmax = lmax
        self.irreps_in1 = Irreps(irreps_in)
        self.irreps_out = Irreps(irreps_out)
        # tensordot mapped over axis 0 of both tensor operands (the batch
        # axis). The contraction-axes argument is shared (None) and the
        # per-sample results are stacked back along axis 0.
        self.tensordot_vmap = vmap(tensordot, (0, 0, None), 0)

    @staticmethod
    def narrow(inputs, axis, start, length):
        """
        Narrow (slice) a tensor along a specified axis.

        Args:
            inputs (Tensor): The tensor to be sliced.
            axis (int): The axis along which to perform the slice. Negative
                values count from the end.
            start (int): The starting index of the slice.
            length (int): The number of elements to include in the slice.

        Returns:
            Tensor, the sliced tensor. Its shape matches ``inputs`` except
            along ``axis``, where it has size ``length``.
        """
        # Every other axis is taken in full; only ``axis`` is restricted.
        begins = [0] * inputs.ndim
        begins[axis] = start
        sizes = list(inputs.shape)
        sizes[axis] = length
        return ops.slice(inputs, begins, sizes)

    @staticmethod
    def rotation_to_wigner_d_matrix(edge_rot_mat, start_lmax, end_lmax):
        r"""
        Convert a batch of :math:`3 \times 3` rotation matrices into Wigner-D
        matrices for the specified range of angular momenta.

        Args:
            edge_rot_mat (Tensor): Batch of SO(3) rotation matrices of shape
                (..., 3, 3).
            start_lmax (int): Minimum angular momentum to include.
            end_lmax (int): Maximum angular momentum to include.

        Returns:
            list[Tensor], Wigner-D matrices (cast to float32) for
            l = start_lmax ... end_lmax, each of shape (..., 2l+1, 2l+1).
        """
        # Apply each rotation to the y unit vector; the spherical angles of
        # the result give alpha and beta.
        x = edge_rot_mat @ ms.Tensor([0.0, 1.0, 0.0])
        alpha, beta = o3.xyz_to_angles(x)
        # Multiply by the transpose of the (alpha, beta, 0) matrix. Assuming
        # o3.angles_to_matrix returns an orthogonal rotation matrix, this
        # undoes that part of the rotation, and the residual gives gamma.
        rvalue = ops.swapaxes(
            o3.angles_to_matrix(alpha, beta, ops.zeros_like(alpha)), -1, -2
        ) @ edge_rot_mat
        gamma = ops.atan2(rvalue[..., 0, 2], rvalue[..., 0, 0])

        block_list = []
        for degree in range(start_lmax, end_lmax + 1):
            block_list.append(
                wigner_D(degree, alpha, beta, gamma).astype(ms.float32))
        return block_list

    def set_wigner(self, rot_mat3x3):
        r"""
        Compute Wigner-D matrices and their inverses from a batch of
        :math:`3 \times 3` rotation matrices.

        Args:
            rot_mat3x3 (Tensor): Batch of SO(3) rotation matrices of shape
                (..., 3, 3).

        Returns:
            tuple[tuple[Tensor], tuple[Tensor]], A pair ``(wigner, wigner_inv)``.

            - wigner: Wigner-D matrices for l = 0 ... lmax, each of shape
              (..., 2l+1, 2l+1).
            - wigner_inv: The transposed (inverse) Wigner-D matrices for
              l = 0 ... lmax, same shapes.
        """
        wigner = self.rotation_to_wigner_d_matrix(rot_mat3x3, 0, self.lmax)
        # Swap the two matrix axes counted from the end, so the transpose is
        # correct for any number of leading batch axes (as in
        # rotation_to_wigner_d_matrix).
        wigner_inv = [ops.swapaxes(block, -1, -2) for block in wigner]
        return tuple(wigner), tuple(wigner_inv)

    def rotate(self, embedding, wigner):
        """
        Rotate an embedding tensor according to the provided Wigner-D matrices.

        Args:
            embedding (Tensor): Features laid out as ``irreps_in``, of shape
                (batch, irreps_in.dim) or (irreps_in.dim,). A 1-D input is
                treated as a batch of one.
            wigner (tuple[Tensor]): Wigner-D matrices for l = 0 ... lmax, each
                of shape (batch, 2l+1, 2l+1).

        Returns:
            tuple[Tensor], One rotated float32 tensor per irrep in
            ``irreps_in``, each of shape (batch, mul, 2l+1).
        """
        # A 1-D embedding gets a leading axis of size 1 so that the vmap over
        # axis 0 below still has a batch axis to map over.
        lead_shape = (1,) if embedding.ndim == 1 else embedding.shape[:-1]

        rotate_data_list = []
        for (start, length), mir in zip(self.irreps_in1.slice_tuples,
                                        self.irreps_in1.data):
            v_slice = self.narrow(embedding, -1, start, length)
            data = v_slice.reshape(lead_shape + (mir.mul, mir.ir.dim))
            # Per sample: out[u, i] = sum_j data[u, j] * D[i, j], i.e. axis 1
            # of both operands is contracted.
            # NOTE(review): the contraction runs in float16 and is cast back to
            # float32, which loses precision; presumably deliberate for speed
            # on the target hardware -- confirm.
            rotate_data = self.tensordot_vmap(
                data.astype(ms.float16),
                wigner[mir.ir.l].astype(ms.float16),
                (1, 1)).astype(ms.float32)
            rotate_data_list.append(rotate_data)
        return tuple(rotate_data_list)

    def rotate_inv(self, embedding, wigner_inv):
        """
        Apply the inverse SO(3) rotation to per-irrep tensors using the
        provided inverse Wigner-D matrices, and concatenate the results.

        Args:
            embedding (tuple[Tensor]): One tensor per irrep in ``irreps_out``,
                each of shape (batch, mul, 2l+1).
            wigner_inv (tuple[Tensor]): Inverse (transposed) Wigner-D matrices
                for l = 0 ... lmax, each of shape (batch, 2l+1, 2l+1).

        Returns:
            Tensor, float32 tensor of shape (batch, irreps_out.dim), formed by
            concatenating the flattened inverse-rotated irreps.
        """
        # Only the first (batch) axis is kept; the rest is flattened per irrep.
        batch_shape = embedding[0].shape[:1]

        rotate_back_data_list = []
        for index, mir in enumerate(self.irreps_out.data):
            # Per sample: out[u, i] = sum_j data[u, j] * Dinv[i, j].
            # NOTE(review): float16 contraction, see rotate().
            rotate_back_data = self.tensordot_vmap(
                embedding[index].astype(ms.float16),
                wigner_inv[mir.ir.l].astype(ms.float16),
                (1, 1)).astype(ms.float32)
            rotate_back_data_list.append(
                rotate_back_data.view(batch_shape + (-1,)))
        return ops.cat(rotate_back_data_list, -1)