# Source code for mindscience.e3nn.o3.irreps

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Irreducible representations and their direct sums for O(3)."""

import itertools
import collections
import dataclasses

import numpy as np

from mindspore import jit_class, Tensor, ops

from .wigner import wigner_D
from .rotation import matrix_to_angles
from ..utils.func import broadcast_args, _to_tensor, norm_keep, _expand_last_dims, narrow
from ..utils.perm import _inverse
from ..utils.linalg import _direct_sum

# pylint: disable=C0111

@jit_class
@dataclasses.dataclass(init=False, frozen=True)
class Irrep:
    r"""
    Irreducible representation of O(3).

    This class does not contain any data, it is a structure that describes the representation.
    It is typically used as argument of other classes of the library to define the input and
    output representations of functions.

    The irrep is labeled by a non-negative integer `l` (the degree) and a parity `p`
    (1 for even, -1 for odd). Parity letters accepted in strings: "e" for even, "o" for odd,
    "y" for parity :math:`(-1)^l`.

    Args:
        l (Union[int, str, tuple, Irrep]): non-negative integer, the degree of the
            representation, :math:`l = 0, 1, \dots`. Alternatively, a string such as ``"1o"``
            or ``"2e"`` encoding both degree and parity; or, when `p` is ``None``, an
            ``(l, p)`` tuple or an existing `Irrep`.
        p (int, optional): the parity of the representation, :math:`p \in \{1, -1\}`.
            Ignored when ``l`` is a string. Default: ``None``.

    Raises:
        ValueError: If the degree is negative or the parity is not in {1, -1}
            (including an int `l` given without `p`).
        ValueError: If a string `l` cannot be converted to an `Irrep`.
        TypeError: If the degree is not an int.

    Examples:
        >>> from mindscience.e3nn.o3 import Irrep
        >>> Irrep(0, 1)
        0e
        >>> Irrep("1y")
        1o
        >>> Irrep("2o").dim
        5
        >>> Irrep("2e") in Irrep("1o") * Irrep("1o")
        True
        >>> Irrep("1o") + Irrep("2o")
        1x1o+1x2o
    """
    l: int
    p: int

    def __init__(self, l, p=None):
        if isinstance(l, str):
            # A string encodes both degree and parity, so an explicit ``p`` is ignored.
            name = l.strip()
            try:
                degree = int(name[:-1])
                if degree < 0:
                    raise ValueError("Irrep degree must be non-negative.")
                parity = {
                    'e': 1,
                    'o': -1,
                    'y': (-1) ** degree,
                }[name[-1]]
            except (ValueError, KeyError, IndexError) as exc:
                # Report the caller's original string, not a partially parsed degree.
                raise ValueError(f"Cannot convert string {l} to Irrep.") from exc
            l, p = degree, parity
        elif p is None:
            if isinstance(l, Irrep):
                l, p = l.l, l.p
            elif isinstance(l, _MulIr):
                l, p = l.ir.l, l.ir.p
            elif isinstance(l, tuple):
                l, p = l

        if not isinstance(l, int):
            raise TypeError("Irrep degree must be int.")
        if l < 0:
            raise ValueError("Irrep degree must be non-negative.")
        if p not in [-1, 1]:
            raise ValueError("Irrep parity must be 1 or -1.")
        # The dataclass is frozen, so fields must be set through object.__setattr__.
        object.__setattr__(self, "l", l)
        object.__setattr__(self, "p", p)

    def __repr__(self):
        """Representation of the Irrep, e.g. ``1o`` or ``2e``."""
        p = {+1: 'e', -1: 'o'}[self.p]
        return f"{self.l}{p}"

    @classmethod
    def iterator(cls, lmax=None):
        r"""
        Yield irreps degree by degree: 0e, 0o, 1o, 1e, 2e, 2o, ...

        For each degree the irrep with parity :math:`(-1)^l` is yielded first, then the other one.

        Args:
            lmax (int, optional): last degree to yield. If ``None`` (or never reached) the
                generator is infinite. Default: ``None``.
        """
        for l in itertools.count():
            yield Irrep(l, (-1) ** l)
            yield Irrep(l, -(-1) ** l)
            if l == lmax:
                break

    def wigD_from_angles(self, alpha, beta, gamma, k=None):
        r"""
        Compute the Wigner-D matrix representation of O(3) from the three Euler angles
        :math:`(\alpha, \beta, \gamma)` that describe the rotation sequence:

        1. Rotate by :math:`\gamma` around the original Y axis.
        2. Rotate by :math:`\beta` around the new X axis.
        3. Rotate by :math:`\alpha` around the newest Y axis.

        Args:
            alpha (Union[Tensor[float32], list[float], tuple[float], ndarray[np.float32], float]):
                Rotation :math:`\alpha` around Y axis, applied third.
            beta (Union[Tensor[float32], list[float], tuple[float], ndarray[np.float32], float]):
                Rotation :math:`\beta` around X axis, applied second.
            gamma (Union[Tensor[float32], list[float], tuple[float], ndarray[np.float32], float]):
                Rotation :math:`\gamma` around Y axis, applied first.
            k (Union[None, Tensor[float32], list[float], tuple[float], ndarray[np.float32], float], optional):
                How many times the parity is applied; the result is scaled by :math:`p^k`.
                Default: ``None`` (treated as zero).

        Returns:
            Tensor, representation wigner D matrix of O(3). The shape of Tensor is
            :math:`(..., 2l+1, 2l+1)`.

        Examples:
            >>> m = Irrep(1, -1).wigD_from_angles(0, 0, 0, 1)
            >>> print(m)
            [[-1, 0, 0],
             [ 0, -1, 0],
             [ 0, 0, -1]]
        """
        if k is None:
            k = ops.zeros_like(_to_tensor(alpha))
        alpha, beta, gamma, k = broadcast_args(alpha, beta, gamma, k)
        return wigner_D(self.l, alpha, beta, gamma) * self.p ** _expand_last_dims(k)

    def wigD_from_matrix(self, R):
        r"""
        Compute the Wigner-D matrix representation of O(3) from rotation matrices.

        Args:
            R (Tensor): Rotation matrices. The shape of Tensor is :math:`(..., 3, 3)`.

        Returns:
            Tensor, representation wigner D matrix of O(3). The shape of Tensor is
            :math:`(..., 2l+1, 2l+1)`.

        Raises:
            TypeError: If `R` is not a Tensor.

        Examples:
            >>> from mindspore import ops
            >>> m = Irrep(1, -1).wigD_from_matrix(-ops.eye(3))
            >>> print(m)
            [[-1, 0, 0],
             [ 0, -1, 0],
             [ 0, 0, -1]]
        """
        if not isinstance(R, Tensor):
            raise TypeError("R must be a Tensor.")
        # d is the sign of det(R): multiplying by it turns an improper rotation into a proper
        # one, and k = (1 - d) / 2 is 1 exactly when the parity must be applied (d == -1).
        d = Tensor(np.sign(np.linalg.det(R.asnumpy())))
        R = _expand_last_dims(d) * R
        k = (1. - d) / 2
        return self.wigD_from_angles(*matrix_to_angles(R), k)

    @property
    def dim(self) -> int:
        """Dimension of the representation, :math:`2l + 1`."""
        return 2 * self.l + 1

    def is_scalar(self) -> bool:
        r"""
        Check whether this irrep is the trivial (scalar) representation.

        Returns:
            bool, True if `l = 0` and parity `p = 1`, False otherwise.
        """
        return self.l == 0 and self.p == 1

    def __mul__(self, other):
        r"""
        Generate the irreps from the product of two irreps.

        Yields ``Irrep(l, p1 * p2)`` for every ``l`` from ``|l1 - l2|`` to ``l1 + l2`` inclusive.

        Returns:
            generator of `Irrep`.
        """
        other = Irrep(other)
        p = self.p * other.p
        lmin = abs(self.l - other.l)
        lmax = self.l + other.l
        for l in range(lmin, lmax + 1):
            yield Irrep(l, p)

    def __rmul__(self, other):
        r"""
        Return `Irreps` of multiple `Irrep`.

        Args:
            other (int): multiple number of the `Irrep`.

        Returns:
            `Irreps`, corresponding multiple `Irrep`.

        Raises:
            TypeError: If `other` is not int.
        """
        if not isinstance(other, int):
            raise TypeError(
                f"An Irrep can only be multiplied by an int multiplicity, got {type(other).__name__}.")
        return Irreps([(other, self)])

    def __add__(self, other):
        r"""Sum of two irreps."""
        return Irreps(self) + Irreps(other)

    def __radd__(self, other):
        r"""Sum of two irreps."""
        return Irreps(other) + Irreps(self)

    def __iter__(self):
        r"""Deconstruct the irrep into ``l`` and ``p``."""
        yield self.l
        yield self.p

    def __lt__(self, other):
        r"""Order irreps by ``(l, p)``; e.g. ``0o < 0e < 1o < 1e``."""
        if not isinstance(other, Irrep):
            return NotImplemented
        return (self.l, self.p) < (other.l, other.p)

    def __eq__(self, other):
        """Compare two irreps; `other` may be anything convertible to `Irrep` (e.g. ``"1o"``)."""
        try:
            other = Irrep(other)
        except (TypeError, ValueError):
            # Not convertible: let Python fall back (ultimately to False) instead of raising.
            return NotImplemented
        return (self.l, self.p) == (other.l, other.p)
@jit_class
@dataclasses.dataclass(init=False, frozen=True)
class _MulIr:
    """
    An `Irrep` together with its multiplicity, printed as ``MULxIR`` (e.g. ``3x1o``).

    Args:
        mul (Union[int, tuple]): multiplicity, or a ``(mul, ir)`` pair when `ir` is ``None``.
        ir (Irrep, optional): the irrep. Default: ``None``.

    Raises:
        TypeError: If `mul` is not an int or `ir` is not an `Irrep`.
    """
    mul: int
    ir: Irrep

    def __init__(self, mul, ir=None):
        if ir is None:
            mul, ir = mul
        if not (isinstance(mul, int) and isinstance(ir, Irrep)):
            raise TypeError(
                f"_MulIr expects (int, Irrep), got ({type(mul).__name__}, {type(ir).__name__}).")
        # The dataclass is frozen, so fields must be set through object.__setattr__.
        object.__setattr__(self, "mul", mul)
        object.__setattr__(self, "ir", ir)

    @property
    def dim(self):
        """Total dimension, ``mul * ir.dim``."""
        return self.mul * self.ir.dim

    def __repr__(self):
        """Representation of the irrep."""
        return f"{self.mul}x{self.ir}"

    def __iter__(self):
        """Deconstruct the mulirrep into `mul` and `ir`."""
        yield self.mul
        yield self.ir

    def __lt__(self, other):
        """Compare the order of two mulirreps: by irrep first, then by multiplicity."""
        if not isinstance(other, _MulIr):
            return NotImplemented
        return (self.ir, self.mul) < (other.ir, other.mul)

    def __eq__(self, other):
        """Compare two mulirreps; non-`_MulIr` values compare unequal instead of raising."""
        if not isinstance(other, _MulIr):
            return NotImplemented
        return (self.mul, self.ir) == (other.mul, other.ir)
@jit_class
@dataclasses.dataclass(init=False, frozen=False)
class Irreps:
    r"""
    Direct sum of irreducible representations of O(3).

    This class does not contain any data, it is a structure that describes the representation.
    It is typically used as argument of other classes of the library to define the input and
    output representations of functions.

    The irreps are stored as a tuple of `_MulIr` objects, each containing a multiplicity and an
    `Irrep`. This allows for easy manipulation, such as addition, multiplication, and filtering
    of representations.

    Args:
        irreps (Union[str, Irrep, Irreps, list[tuple[int]]], optional): the direct sum of
            irreducible representations, e.g. ``"100x0e+50x1e"``; an `Irrep`; another `Irreps`;
            or an iterable whose items are strings, `Irrep` objects, ints (degree with even
            parity) or ``(mul, ir)`` pairs. Default: ``None`` (empty).

    Raises:
        ValueError: If `irreps` cannot be converted to an `Irreps`, including when an item of an
            iterable is not one of the accepted forms.
        ValueError: If a multiplicity is negative or not an int.
        TypeError: If `irreps`, or the irrep part of a ``(mul, ir)`` pair, has an unsupported type.

    Examples:
        >>> from mindscience.e3nn.o3 import Irrep, Irreps
        >>> x = Irreps([(100, (0, 1)), (50, (1, 1))])
        >>> x
        100x0e+50x1e
        >>> x.dim
        250
        >>> Irreps("100x0e+50x1e+0x2e")
        100x0e+50x1e+0x2e
        >>> Irreps("100x0e+50x1e+0x2e").lmax
        1
        >>> Irrep("2e") in Irreps("0e+2e")
        True
        >>> Irreps(), Irreps("")
        (, )
        >>> Irreps('2x1o+1x0o') * Irreps('2x1o+1x0e')
        1x0o+4x0e+2x1o+4x1e+2x1e+4x2e
    """
    __slots__ = ('data', 'dim', 'slice', 'slice_tuples')

    def __init__(self, irreps=None):
        if isinstance(irreps, Irreps):
            # Reuse the already computed layout of the source (attributes are shared references).
            self.data = irreps.data
            self.dim = irreps.dim
            self.slice = irreps.slice
            self.slice_tuples = irreps.slice_tuples
            return

        if irreps is None:
            out = ()
        elif isinstance(irreps, Irrep):
            out = (_MulIr(1, Irrep(irreps)),)
        elif isinstance(irreps, _MulIr):
            out = (irreps,)
        elif isinstance(irreps, str):
            out = ()
            try:
                if irreps.strip() != "":
                    for term in irreps.split('+'):
                        mul, ir = self._parse_mulir(term)
                        if mul < 0:
                            raise ValueError("Irrep multiplicity must be non-negative.")
                        out += (_MulIr(mul, ir),)
            except (ValueError, TypeError) as exc:
                raise ValueError("Irreps string format is invalid.") from exc
        else:
            out = self._handle_irreps(irreps, ())

        self.data = out
        self.dim = self._dim()
        self.slice = self._slices()
        self.slice_tuples = [(s.start, s.stop - s.start) for s in self.slice]

    @staticmethod
    def _parse_mulir(term):
        """Parse one ``"MULxIR"`` or ``"IR"`` string term into ``(mul, Irrep)``."""
        if 'x' in term:
            mul, ir = term.split('x')
            return int(mul), Irrep(ir)
        return 1, Irrep(term)

    def _handle_irreps(self, irreps, out):
        """Append every item of the iterable `irreps` to the tuple `out` as a `_MulIr`."""
        for mir in irreps:
            if isinstance(mir, str):
                mul, ir = self._parse_mulir(mir)
            elif isinstance(mir, Irrep):
                mul, ir = 1, mir
            elif isinstance(mir, _MulIr):
                mul, ir = mir
            elif isinstance(mir, int):
                mul, ir = 1, Irrep(l=mir, p=1)
            else:
                # Anything else must be a ``(mul, ir)`` pair; reject it explicitly rather than
                # silently reusing ``mul``/``ir`` left over from the previous item.
                try:
                    mul, ir = mir
                except (TypeError, ValueError) as exc:
                    raise ValueError("Irreps format is invalid.") from exc
                ir = Irrep(ir)
            if not (isinstance(mul, int) and mul >= 0):
                raise ValueError("Irreps format is invalid.")
            out += (_MulIr(mul, ir),)
        return out

    def __iter__(self):
        return iter(self.data)

    def __hash__(self):
        return hash(self.data)

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        """Representation of the irreps."""
        return "+".join(f"{mir}" for mir in self.data)

    def __eq__(self, other):
        """Compare two irreps term by term, converting `other` to `Irreps` first."""
        other = Irreps(other)
        return self.data == other.data

    def __contains__(self, ir):
        """
        Check if an irrep or an irreps is in the representation.

        An `Irrep` is contained if it appears in any term. An `Irreps` is contained if each of
        its terms can be matched to a distinct term of `self` with the same irrep and at least
        the same multiplicity.
        """
        try:
            ir = Irrep(ir)
        except (TypeError, ValueError):
            irreps = Irreps(ir)
        else:
            return ir in (irrep for _, irrep in self.data)

        m, n = len(irreps), len(self)
        used = [False] * n

        def dfs(i):
            # Match term i (and every later term) of `irreps` to unused terms of `self`,
            # backtracking when a choice leads to a dead end.
            if i == m:
                return True
            want = irreps.data[i]
            for j in range(n):
                have = self.data[j]
                if not used[j] and want.mul <= have.mul and want.ir == have.ir:
                    used[j] = True
                    if dfs(i + 1):
                        return True
                    used[j] = False
            return False

        return dfs(0)

    def __add__(self, irreps):
        """Concatenate two representations (no simplification)."""
        irreps = Irreps(irreps)
        return Irreps(self.data + irreps.data)

    def __mul__(self, other):
        r"""
        Multiply by an int multiplicity, or take the irrep-level product with another `Irreps`.

        Args:
            other (Union[int, Irreps]): an int scales every multiplicity. An `Irreps` yields every
                irrep in the products of all pairs of terms (multiplicities multiplied); the
                result is simplified (adjacent equal irreps merged) and then sorted, so equal
                irreps that only become adjacent after sorting stay separate.

        Returns:
            `Irreps`, the resulting representation.
        """
        if isinstance(other, Irreps):
            res = Irreps()
            for mir_1 in self.data:
                for mir_2 in other.data:
                    for ir in mir_1.ir * mir_2.ir:
                        res += mir_1.mul * mir_2.mul * ir
            res, _, _ = res.simplify().sort()
            return res
        return Irreps([(mul * other, ir) for mul, ir in self.data])

    def __rmul__(self, other):
        r"""
        Return repeated `Irreps` of multiple `Irreps`.

        Args:
            other (int): multiple number of the `Irreps`.

        Returns:
            `Irreps`, repeated multiple `Irreps`.
        """
        return self * other

    def _dim(self):
        """Total dimension of the representation, :math:`\\sum mul \\cdot (2l + 1)`."""
        return sum(mul * ir.dim for mul, ir in self.data)

    @property
    def num_irreps(self):
        """Total number of irreps, i.e. the sum of all multiplicities."""
        return sum(mul for mul, _ in self.data)

    @property
    def ls(self):
        """List of degrees, each repeated according to its multiplicity."""
        res = []
        for mul, (l, _) in self.data:
            res.extend([l] * mul)
        return res

    @property
    def lmax(self):
        """Largest degree with non-zero multiplicity; raises ValueError for an empty `Irreps`."""
        if len(self) == 0:
            raise ValueError("Cannot get lmax of empty Irreps")
        return max(self.ls)

    def count(self, ir):
        r"""
        Multiplicity of `ir`.

        Args:
            ir (Irrep): Irreducible representation.

        Returns:
            int, total multiplicity of `ir`.

        Examples:
            >>> Irreps("1o + 3x2e").count("2e")
            3
        """
        ir = Irrep(ir)
        return sum(mul for mul, irrep in self.data if ir == irrep)

    def simplify(self):
        """
        Simplify the representations by merging adjacent equal irreps and dropping zero
        multiplicities.

        Returns:
            `Irreps`, simplified `Irreps`.

        Examples:
            >>> Irreps("1e + 1e + 0e").simplify()
            2x1e+1x0e
            >>> Irreps("1e + 1e + 0e + 1e").simplify()
            2x1e+1x0e+1x1e
        """
        out = []
        for mul, ir in self.data:
            if out and out[-1][1] == ir:
                out[-1] = (out[-1][0] + mul, ir)
            elif mul > 0:
                out.append((mul, ir))
        return Irreps(out)

    def remove_zero_multiplicities(self):
        """
        Remove any irreps with multiplicities of zero.

        Returns:
            `Irreps`, irreps with multiplicities of zero removed.

        Examples:
            >>> Irreps("4x0e + 0x1o + 2x3e").remove_zero_multiplicities()
            4x0e+2x3e
        """
        return Irreps([(mul, ir) for mul, ir in self.data if mul > 0])

    def _slices(self):
        r"""
        List of slices corresponding to indices for each irrep.

        Examples:
            >>> Irreps('2x0e + 1e').slice
            [slice(0, 2, None), slice(2, 5, None)]
        """
        s = []
        i = 0
        for mir in self.data:
            s.append(slice(i, i + mir.dim))
            i += mir.dim
        return s

    def sort(self):
        r"""
        Sort the representations by increasing ``(l, p)``.

        Returns:
            - `Irreps`, sorted `Irreps`.
            - `p (tuple[int])`, permute orders, `p[old_index] = new_index`.
            - `inv (tuple[int])`, inversed permute orders, `inv[new_index] = old_index`.

        Examples:
            >>> Irreps("1e + 0e + 1e").sort().irreps
            1x0e+1x1e+1x1e
            >>> Irreps("2o + 1e + 0e + 1e").sort().p
            (3, 1, 0, 2)
            >>> Irreps("2o + 1e + 0e + 1e").sort().inv
            (2, 1, 3, 0)
        """
        Ret = collections.namedtuple("sort", ["irreps", "p", "inv"])
        # The original index breaks ties, so equal irreps keep their relative order.
        out = sorted((ir, i, mul) for i, (mul, ir) in enumerate(self.data))
        inv = tuple(i for _, i, _ in out)
        p = _inverse(inv)
        irreps = Irreps([(mul, ir) for ir, _, mul in out])
        return Ret(irreps, p, inv)

    def filter(self, keep=None, drop=None):
        r"""
        Filter the `Irreps` by either `keep` or `drop`.

        Args:
            keep (Union[str, Irrep, Irreps, list[str, Irrep]], optional): list of irrep to keep.
                Default: ``None``.
            drop (Union[str, Irrep, Irreps, list[str, Irrep]], optional): list of irrep to drop.
                Default: ``None``.

        Returns:
            `Irreps`, filtered irreps (`self` unchanged if both are ``None``).

        Raises:
            ValueError: If both `keep` and `drop` are not `None`.

        Examples:
            >>> Irreps("1o + 2e").filter(keep="1o")
            1x1o
            >>> Irreps("1o + 2e").filter(drop="1o")
            1x2e
        """
        if keep is None and drop is None:
            return self
        if keep is not None and drop is not None:
            raise ValueError("Cannot specify both keep and drop")
        if keep is not None:
            keep_set = {mir.ir for mir in Irreps(keep).data}
            return Irreps([(mul, ir) for mul, ir in self.data if ir in keep_set])
        drop_set = {mir.ir for mir in Irreps(drop).data}
        return Irreps([(mul, ir) for mul, ir in self.data if ir not in drop_set])

    def decompose(self, v, batch=False):
        r"""
        Decompose a vector into irreducible components according to the current `Irreps`
        structure.

        The last axis of `v` is split so that each slice corresponds to one term of `self`.
        The resulting list contains one tensor per term, with shape
        `(..., multiplicity, irrep_dimension)`.

        Args:
            v (Tensor): the vector to be decomposed.
            batch (bool, optional): whether reshape the result such that there is at least a
                batch dimension. Default: ``False``.

        Returns:
            List of Tensors, the decomposed vectors by `Irreps`.

        Raises:
            TypeError: If v is not Tensor.
            ValueError: If length of the vector `v` is not matching with dimension of `Irreps`.

        Examples:
            >>> import mindspore as ms
            >>> input = ms.Tensor([1, 2, 3])
            >>> m = Irreps("1o").decompose(input)
            >>> print(m)
            [Tensor(shape=[1,3], dtype=Int64, value= [[1,2,3]])]
        """
        if not isinstance(v, Tensor):
            raise TypeError(
                f"The input for decompose should be Tensor, but got {type(v)}.")
        len_v = v.shape[-1]
        if self.dim != len_v:
            raise ValueError(
                f"the shape of input {v.shape[-1]} do not match irreps dimension {self.dim}.")

        res = []
        batch_shape = v.shape[:-1]
        # A 1-D input has no batch axis; prepend one only when `batch` is requested.
        lead = (1,) if (v.ndim == 1 and batch) else ()
        for (s, l), mir in zip(self.slice_tuples, self.data):
            v_slice = narrow(v, -1, s, l)
            res.append(v_slice.reshape(lead + batch_shape + (mir.mul, mir.ir.dim)))
        return res

    @staticmethod
    def spherical_harmonics(lmax, p=-1):
        r"""
        Representation of the spherical harmonics.

        Args:
            lmax (int): maximum of `l`.
            p (int, optional): {1, -1}, the parity of the representation. Default: ``-1``.

        Returns:
            `Irreps`, representation of :math:`(Y^0, Y^1, \dots, Y^{\mathrm{lmax}})`.

        Examples:
            >>> Irreps.spherical_harmonics(3)
            1x0e+1x1o+1x2e+1x3o
            >>> Irreps.spherical_harmonics(4, p=1)
            1x0e+1x1e+1x2e+1x3e+1x4e
        """
        return Irreps([(1, (l, p ** l)) for l in range(lmax + 1)])

    def randn(self, *size, normalization='component'):
        r"""
        Generate a random tensor whose last-but-inserted dimension matches the total dimension
        of these irreps.

        With ``'norm'`` normalization each irrep block is normalized along its component axis
        (via `norm_keep`); with ``'component'`` plain standard-normal samples are returned.

        Args:
            \*size (list[int]): size of the output tensor, needs to contains a `-1`.
            normalization (str, optional): {'component', 'norm'}, type of normalization method.
                Default: ``'component'``.

        Returns:
            Tensor, the shape is `size` where `-1` is replaced by `self.dim`.

        Raises:
            ValueError: If `normalization` is not 'component' or 'norm'.

        Examples:
            >>> Irreps("5x0e + 10x1o").randn(5, -1, 5, normalization='norm').shape
            (5, 35, 5)
        """
        di = size.index(-1)
        lsize = size[:di]
        rsize = size[di + 1:]

        if normalization == 'component':
            return ops.standard_normal((*lsize, self.dim, *rsize))

        if normalization == 'norm':
            x_list = []
            for mul, ir in self.data:
                if mul < 1:
                    continue
                r = ops.standard_normal((*lsize, mul, ir.dim, *rsize))
                # Axis di + 1 of r is the ir.dim (component) axis built just above.
                r = r / norm_keep(r, axis=di + 1)
                x_list.append(r.reshape((*lsize, -1, *rsize)))
            return ops.concat(x_list, axis=di)

        raise ValueError("Normalization needs to be 'norm' or 'component'")

    def wigD_from_angles(self, alpha, beta, gamma, k=None):
        r"""
        Compute the Wigner-D matrix representation of O(3) from the three Euler angles
        :math:`(\alpha, \beta, \gamma)` that describe the rotation sequence:

        1. Rotate by :math:`\gamma` around the original Y axis.
        2. Rotate by :math:`\beta` around the new X axis.
        3. Rotate by :math:`\alpha` around the newest Y axis.

        The result is the direct sum of the Wigner-D matrices for each irrep contained in this
        `Irreps` object, repeated according to multiplicity.

        Args:
            alpha (Union[Tensor[float32], list[float], tuple[float], ndarray[np.float32], float]):
                rotation :math:`\alpha` around Y axis, applied third.
            beta (Union[Tensor[float32], list[float], tuple[float], ndarray[np.float32], float]):
                rotation :math:`\beta` around X axis, applied second.
            gamma (Union[Tensor[float32], list[float], tuple[float], ndarray[np.float32], float]):
                rotation :math:`\gamma` around Y axis, applied first.
            k (Union[None, Tensor[float32], list[float], tuple[float], ndarray[np.float32], float], optional):
                How many times the parity is applied. Default: ``None``.

        Returns:
            Tensor, representation wigner D matrix of O(3). The shape of Tensor is
            :math:`(..., D, D)` with ``D = self.dim``.

        Examples:
            >>> m = Irreps("1o").wigD_from_angles(0, 0, 0, 1)
            >>> print(m)
            [[-1, 0, 0],
             [ 0, -1, 0],
             [ 0, 0, -1]]
        """
        return _direct_sum(
            *[ir.wigD_from_angles(alpha, beta, gamma, k) for mul, ir in self for _ in range(mul)]
        )

    def wigD_from_matrix(self, R):
        r"""
        Compute Wigner-D matrices of O(3) from rotation matrices.

        Args:
            R (Tensor): Rotation matrices. The shape of Tensor is :math:`(..., 3, 3)`.

        Returns:
            Tensor, representation wigner D matrix of O(3). The shape of Tensor is
            :math:`(..., D, D)` with ``D = self.dim``.

        Raises:
            TypeError: If `R` is not a Tensor.

        Examples:
            >>> m = Irreps("1o").wigD_from_matrix(-ops.eye(3))
            >>> print(m)
            [[-1, 0, 0],
             [ 0, -1, 0],
             [ 0, 0, -1]]
        """
        if not isinstance(R, Tensor):
            raise TypeError("R needs to be a Tensor")
        # d is the sign of det(R); d * R is a proper rotation and k is 1 where d == -1.
        d = Tensor(np.sign(np.linalg.det(R.asnumpy())))
        R = _expand_last_dims(d) * R
        k = (1 - d) / 2
        return self.wigD_from_angles(*matrix_to_angles(R), k)