# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import itertools
import collections
import dataclasses
import numpy as np
from mindspore import jit_class, Tensor, ops
from .wigner import wigner_D
from .rotation import matrix_to_angles
from ..utils.func import broadcast_args, _to_tensor, norm_keep, _expand_last_dims, narrow
from ..utils.perm import _inverse
from ..utils.linalg import _direct_sum
# pylint: disable=C0111
[docs]@jit_class
@dataclasses.dataclass(init=False, frozen=True)
class Irrep:
r"""
Irreducible representation of O(3). This class does not contain any data, it is a structure that describe the representation.
It is typically used as argument of other classes of the library to define the input and output representations of functions.
Args:
l (Union[int, str]): non-negative integer, the degree of the representation, :math:`l = 0, 1, \dots`. Or string to indicate the degree and parity.
p (int): {1, -1}, the parity of the representation. Default: ``None``.
Raises:
NotImplementedError: If method is not implemented.
ValueError: If `l` is negative or `p` is not in {1, -1}.
ValueError: If `l` cannot be converted to an `Irrep`.
TypeError: If `l` is not int or str.
Supported Platforms:
``Ascend``
Examples:
>>> from mindchemistry.e3.o3 import Irrep
>>> Irrep(0, 1)
0e
>>> Irrep("1y")
1o
>>> Irrep("2o").dim
5
>>> Irrep("2e") in Irrep("1o") * Irrep("1o")
True
>>> Irrep("1o") + Irrep("2o")
1x1o+1x2o
"""
l: int
p: int
def __init__(self, l, p=None):
if p is None:
if isinstance(l, Irrep):
p = l.p
l = l.l
if isinstance(l, _MulIr):
p = l.ir.p
l = l.ir.l
if isinstance(l, str):
try:
name = l.strip()
l = int(name[:-1])
if l < 0:
raise ValueError
p = {
'e': 1,
'o': -1,
'y': (-1) ** l,
}[name[-1]]
except Exception:
raise ValueError
elif isinstance(l, tuple):
l, p = l
if not isinstance(l, int):
raise TypeError
elif l < 0:
raise ValueError
if p not in [-1, 1]:
raise ValueError
object.__setattr__(self, "l", l)
object.__setattr__(self, "p", p)
def __repr__(self):
"""Representation of the Irrep."""
p = {+1: 'e', -1: 'o'}[self.p]
return f"{self.l}{p}"
@classmethod
def iterator(cls, lmax=None):
for l in itertools.count():
yield Irrep(l, (-1) ** l)
yield Irrep(l, -(-1) ** l)
if l == lmax:
break
[docs] def wigD_from_angles(self, alpha, beta, gamma, k=None):
r"""
Representation wigner D matrices of O(3) from Euler angles.
Args:
alpha (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\alpha` around Y axis, applied third.
beta (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\beta` around X axis, applied second.
gamma (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\gamma` around Y axis, applied first.
k (Union[None, Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): How many times the parity is applied. Default: ``None`` .
Returns:
Tensor, representation wigner D matrix of O(3). The shape of Tensor is :math:`(..., 2l+1, 2l+1)` .
Examples:
>>> m = Irrep(1, -1).wigD_from_angles(0, 0 ,0, 1)
>>> print(m)
[[-1, 0, 0],
[ 0, -1, 0],
[ 0, 0, -1]]
"""
if k is None:
k = ops.zeros_like(_to_tensor(alpha))
alpha, beta, gamma, k = broadcast_args(alpha, beta, gamma, k)
return wigner_D(self.l, alpha, beta, gamma) * self.p ** _expand_last_dims(k)
[docs] def wigD_from_matrix(self, R):
r"""
Representation wigner D matrices of O(3) from rotation matrices.
Args:
R (Tensor): Rotation matrices. The shape of Tensor is :math:`(..., 3, 3)`.
Returns:
Tensor, representation wigner D matrix of O(3). The shape of Tensor is :math:`(..., 2l+1, 2l+1)`.
Raises:
TypeError: If `R` is not a Tensor.
Examples:
>>> from mindspore import ops
>>> m = Irrep(1, -1).wigD_from_matrix(-ops.eye(3))
>>> print(m)
[[-1, 0, 0],
[ 0, -1, 0],
[ 0, 0, -1]]
"""
if not isinstance(R, Tensor):
raise TypeError
d = Tensor(np.sign(np.linalg.det(R.asnumpy())))
R = _expand_last_dims(d) * R
k = (1. - d) / 2
return self.wigD_from_angles(*matrix_to_angles(R), k)
@property
def dim(self) -> int:
return 2 * self.l + 1
def is_scalar(self) -> bool:
return self.l == 0 and self.p == 1
def __mul__(self, other):
r"""
Generate the irreps from the product of two irreps.
Returns:
generator of `Irrep`.
"""
other = Irrep(other)
p = self.p * other.p
lmin = abs(self.l - other.l)
lmax = self.l + other.l
for l in range(lmin, lmax + 1):
yield Irrep(l, p)
def __rmul__(self, other):
r"""
Return `Irreps` of multiple `Irrep`.
Args:
other (int): multiple number of the `Irrep`.
Returns:
`Irreps` - corresponding multiple `Irrep`.
Raises:
TypeError: If `other` is not int.
"""
if not isinstance(other, int):
raise TypeError
return Irreps([(other, self)])
def __add__(self, other):
r"""Sum of two irreps."""
return Irreps(self) + Irreps(other)
def __radd__(self, other):
r"""Sum of two irreps."""
return Irreps(other) + Irreps(self)
def __iter__(self):
r"""Deconstruct the irrep into ``l`` and ``p``."""
yield self.l
yield self.p
def __lt__(self, other):
r"""Compare the order of two irreps."""
return (self.l, self.p) < (other.l, other.p)
def __eq__(self, other):
"""Compare two irreps."""
other = Irrep(other)
return (self.l, self.p) == (other.l, other.p)
@jit_class
@dataclasses.dataclass(init=False, frozen=True)
class _MulIr:
"""Multiple Irrep."""
mul: int
ir: Irrep
def __init__(self, mul, ir=None):
if ir is None:
mul, ir = mul
if not (isinstance(mul, int) and isinstance(ir, Irrep)):
raise TypeError
object.__setattr__(self, "mul", mul)
object.__setattr__(self, "ir", ir)
@property
def dim(self):
return self.mul * self.ir.dim
def __repr__(self):
"""Representation of the irrep."""
return f"{self.mul}x{self.ir}"
def __iter__(self):
"""Deconstruct the mulirrep into `mul` and `ir`."""
yield self.mul
yield self.ir
def __lt__(self, other):
"""Compare the order of two mulirreps."""
return (self.ir, self.mul) < (other.ir, other.mul)
def __eq__(self, other):
"""Compare two irreps."""
return (self.mul, self.ir) == (other.mul, other.ir)
[docs]@jit_class
@dataclasses.dataclass(init=False, frozen=False)
class Irreps:
r"""
Direct sum of irreducible representations of O(3). This class does not contain any data, it is a structure that describe the representation.
It is typically used as argument of other classes of the library to define the input and output representations of functions.
Args:
irreps (Union[str, Irrep, Irreps, List[Tuple[int]]]): a string to represent the direct sum of irreducible representations.
Raises:
ValueError: If `irreps` cannot be converted to an `Irreps`.
ValueError: If the mul part of `irreps` part is negative.
TypeError: If the mul part of `irreps` part is not int.
Supported Platforms:
``Ascend``
Examples:
>>> from mindchemistry.e3.o3 import Irreps
>>> x = Irreps([(100, (0, 1)), (50, (1, 1))])
100x0e+50x1e
>>> x.dim
250
>>> Irreps("100x0e+50x1e+0x2e")
100x0e+50x1e+0x2e
>>> Irreps("100x0e+50x1e+0x2e").lmax
1
>>> Irrep("2e") in Irreps("0e+2e")
True
>>> Irreps(), Irreps("")
(, )
>>> Irreps('2x1o+1x0o') * Irreps('2x1o+1x0e')
4x0e+1x0o+2x1o+4x1e+2x1e+4x2e
"""
__slots__ = ('data', 'dim', 'slice', 'slice_tuples')
def __init__(self, irreps=None):
if isinstance(irreps, Irreps):
self.data = irreps.data
self.dim = irreps.dim
self.slice = irreps.slice
self.slice_tuples = irreps.slice_tuples
else:
out = ()
if isinstance(irreps, Irrep):
out += (_MulIr(1, Irrep(irreps)),)
elif isinstance(irreps, _MulIr):
out += (irreps,)
elif isinstance(irreps, str):
try:
if irreps.strip() != "":
for mir in irreps.split('+'):
if 'x' in mir:
mul, ir = mir.split('x')
mul = int(mul)
ir = Irrep(ir)
else:
mul = 1
ir = Irrep(mir)
if not isinstance(mul, int):
raise TypeError
elif mul < 0:
raise ValueError
out += (_MulIr(mul, ir),)
except Exception:
raise ValueError
elif irreps is None:
pass
else:
out = self.handle_irreps(irreps, out)
self.data = out
self.dim = self._dim()
self.slice = self._slices()
self.slice_tuples = [(s.start, s.stop - s.start) for s in self.slice]
def handle_irreps(self, irreps, out):
for mir in irreps:
if isinstance(mir, str):
if 'x' in mir:
mul, ir = mir.split('x')
mul = int(mul)
ir = Irrep(ir)
else:
mul = 1
ir = Irrep(mir)
elif isinstance(mir, Irrep):
mul = 1
ir = mir
elif isinstance(mir, _MulIr):
mul, ir = mir
elif isinstance(mir, int):
mul, ir = 1, Irrep(l=mir, p=1)
elif len(mir) == 2:
mul, ir = mir
ir = Irrep(ir)
if not (isinstance(mul, int) and mul >= 0 and ir is not None):
raise ValueError
out += (_MulIr(mul, ir),)
return out
def __iter__(self):
return iter(self.data)
def __hash__(self):
return hash(self.data)
def __len__(self):
return len(self.data)
def __repr__(self):
"""Representation of the irreps."""
return "+".join(f"{mir}" for mir in self.data)
def __eq__(self, other):
"""Compare two irreps."""
other = Irreps(other)
if not len(self) == len(other):
return False
for m_1, m_2 in zip(self.data, other.data):
if not m_1 == m_2:
return False
return True
def __contains__(self, ir):
"""Check if an irrep or an irreps is in the representation."""
try:
ir = Irrep(ir)
return ir in (irrep for _, irrep in self.data)
except:
irreps = Irreps(ir)
m, n = len(irreps), len(self)
mask = [False] * n
def dfs(i):
if i == m:
return True
for j in range(n):
if not mask[j]:
if irreps.data[i].mul <= self.data[j].mul and irreps.data[i].ir == self.data[j].ir:
mask[j] = True
found = dfs(i + 1)
if found:
return True
mask[j] = False
return False
return dfs(0)
def __add__(self, irreps):
irreps = Irreps(irreps)
return Irreps(self.data.__add__(irreps.data))
def __mul__(self, other):
r"""
Return `Irreps` of multiple `Irreps`.
Args:
other (int): multiple number of the `Irreps`.
Returns:
`Irreps` - corresponding multiple `Irreps`.
Raises:
NotImplementedError: If `other` is `Irreps`, please use `o3.TensorProduct`.
"""
if isinstance(other, Irreps):
res = Irreps()
for mir_1 in self.data:
for mir_2 in other.data:
out_ir = mir_1.ir * mir_2.ir
for ir in out_ir:
res += mir_1.mul * mir_2.mul * ir
res, p, _ = res.simplify().sort()
return res
return Irreps([(mul * other, ir) for mul, ir in self.data])
def __rmul__(self, other):
r"""
Return repeated `Irreps` of multiple `Irreps`.
Args:
other (int): multiple number of the `Irreps`.
Returns:
`Irreps` - repeated multiple `Irreps`.
"""
return self * other
def _dim(self):
"""The dimension of the representation, :math:`2 l + 1`."""
return sum(mul * ir.dim for mul, ir in self.data)
@property
def num_irreps(self):
return sum(mul for mul, _ in self.data)
@property
def ls(self):
res = []
for mul, (l, _) in self.data:
res.extend([l] * mul)
return res
@property
def lmax(self):
if len(self) == 0:
raise ValueError("Cannot get lmax of empty Irreps")
return max(self.ls)
[docs] def count(self, ir):
r"""
Multiplicity of `ir`.
Args:
ir (Irrep): `Irrep`
Returns:
int, total multiplicity of `ir`.
Examples:
>>> Irreps("1o + 3x2e").count("2e")
3
"""
ir = Irrep(ir)
res = 0
for mul, irrep in self.data:
if ir == irrep:
res += mul
return res
[docs] def simplify(self):
"""
Simplify the representations.
Returns:
`Irreps`
Examples:
>>> Irreps("1e + 1e + 0e").simplify()
2x1e+1x0e
>>> Irreps("1e + 1e + 0e + 1e").simplify()
2x1e+1x0e+1x1e
"""
out = []
for mul, ir in self.data:
if out and out[-1][1] == ir:
out[-1] = (out[-1][0] + mul, ir)
elif mul > 0:
out.append((mul, ir))
return Irreps(out)
[docs] def remove_zero_multiplicities(self):
"""
Remove any irreps with multiplicities of zero.
Returns:
`Irreps`
Examples:
>>> Irreps("4x0e + 0x1o + 2x3e").remove_zero_multiplicities()
4x0e+2x3e
"""
out = [(mul, ir) for mul, ir in self.data if mul > 0]
return Irreps(out)
def _slices(self):
r"""
List of slices corresponding to indices for each irrep.
Examples:
>>> Irreps('2x0e + 1e').slices()
[slice(0, 2, None), slice(2, 5, None)]
"""
s = []
i = 0
for mir in self.data:
s.append(slice(i, i + mir.dim))
i += mir.dim
return s
[docs] def sort(self):
r"""
Sort the representations by increasing degree.
Returns:
irreps (`Irreps`) - sorted `Irreps`
p (tuple[int]) - permute orders. `p[old_index] = new_index`
inv (tuple[int]) - inversed permute orders. `p[new_index] = old_index`
Examples:
>>> Irreps("1e + 0e + 1e").sort().irreps
1x0e+1x1e+1x1e
>>> Irreps("2o + 1e + 0e + 1e").sort().p
(3, 1, 0, 2)
>>> Irreps("2o + 1e + 0e + 1e").sort().inv
(2, 1, 3, 0)
"""
Ret = collections.namedtuple("sort", ["irreps", "p", "inv"])
out = [(ir, i, mul) for i, (mul, ir) in enumerate(self.data)]
out = sorted(out)
inv = tuple(i for _, i, _ in out)
p = _inverse(inv)
irreps = Irreps([(mul, ir) for ir, _, mul in out])
return Ret(irreps, p, inv)
[docs] def filter(self, keep=None, drop=None):
r"""
Filter the `Irreps` by either `keep` or `drop`.
Args:
keep (Union[str, Irrep, Irreps, List[str, Irrep]]): list of irrep to keep. Default: None.
drop (Union[str, Irrep, Irreps, List[str, Irrep]]): list of irrep to drop. Default: None.
Returns:
`Irreps`, filtered irreps.
Raises:
ValueError: If both `keep` and `drop` are not `None`.
Examples:
>>> Irreps("1o + 2e").filter(keep="1o")
1x1o
>>> Irreps("1o + 2e").filter(drop="1o")
1x2e
"""
if keep is None and drop is None:
return self
if keep is not None and drop is not None:
raise ValueError("Cannot specify both keep and drop")
if keep is not None:
keep = Irreps(keep).data
keep = {mir.ir for mir in keep}
return Irreps([(mul, ir) for mul, ir in self.data if ir in keep])
if drop is not None:
drop = Irreps(drop).data
drop = {mir.ir for mir in drop}
return Irreps([(mul, ir) for mul, ir in self.data if not ir in drop])
return None
[docs] def decompose(self, v, batch=False):
r"""
Decompose a vector by `Irreps`.
Args:
v (Tensor): the vector to be decomposed.
batch (bool): whether reshape the result such that there is at least a batch dimension. Default: `False`.
Returns:
List of Tensors, the decomposed vectors by `Irreps`.
Raises:
TypeError: If v is not Tensor.
ValueError: If length of the vector `v` is not matching with dimension of `Irreps`.
Examples:
>>> import mindspore as ms
>>> input = ms.Tensor([1, 2, 3])
>>> m = Irreps("1o").decompose(input)
>>> print(m)
[Tensor(shape=[1,3], dtype=Int64, value=
[[1,2,3]])]
"""
if not isinstance(v, Tensor):
raise TypeError(
f"The input for decompose should be Tensor, but got {type(v)}.")
len_v = v.shape[-1]
if not self.dim == len_v:
raise ValueError(
f"the shape of input {v.shape[-1]} do not match irreps dimension {self.dim}.")
res = []
batch_shape = v.shape[:-1]
for (s, l), mir in zip(self.slice_tuples, self.data):
v_slice = narrow(v, -1, s, l)
if v.ndim == 1 and batch:
res.append(v_slice.reshape(
(1,) + batch_shape + (mir.mul, mir.ir.dim)))
else:
res.append(v_slice.reshape(
batch_shape + (mir.mul, mir.ir.dim)))
return res
[docs] @staticmethod
def spherical_harmonics(lmax, p=-1):
r"""
Representation of the spherical harmonics.
Args:
lmax (int): maximum of `l`.
p (int): {1, -1}, the parity of the representation.
Returns:
`Irreps`, representation of :math:`(Y^0, Y^1, \dots, Y^{\mathrm{lmax}})`.
Examples:
>>> Irreps.spherical_harmonics(3)
1x0e+1x1o+1x2e+1x3o
>>> Irreps.spherical_harmonics(4, p=1)
1x0e+1x1e+1x2e+1x3e+1x4e
"""
return Irreps([(1, (l, p ** l)) for l in range(lmax + 1)])
[docs] def randn(self, *size, normalization='component'):
r"""
Random tensor.
Args:
*size (List[int]): size of the output tensor, needs to contains a `-1`.
normalization (str): {'component', 'norm'}, type of normalization method.
Returns:
Tensor, the shape is `size` where `-1` is replaced by `self.dim`.
Examples:
>>> Irreps("5x0e + 10x1o").randn(5, -1, 5, normalization='norm').shape
(5, 35, 5)
"""
di = size.index(-1)
lsize = size[:di]
rsize = size[di + 1:]
if normalization == 'component':
return ops.standard_normal((*lsize, self.dim, *rsize))
elif normalization == 'norm':
x_list = []
for s, (mul, ir) in zip(self.slice, self.data):
if mul < 1:
continue
r = ops.standard_normal((*lsize, mul, ir.dim, *rsize))
r = r / norm_keep(r, axis=di + 1)
x_list.append(r.reshape((*lsize, -1, *rsize)))
return ops.concat(x_list, axis=di)
else:
raise ValueError("Normalization needs to be 'norm' or 'component'")
[docs] def wigD_from_angles(self, alpha, beta, gamma, k=None):
r"""
Representation wigner D matrices of O(3) from Euler angles.
Args:
alpha (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\alpha` around Y axis, applied third.
beta (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\beta` around X axis, applied second.
gamma (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\gamma` around Y axis, applied first.
k (Union[None, Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): How many times the parity is applied. Default: None.
Returns:
Tensor, representation wigner D matrix of O(3). The shape of Tensor is :math:`(..., 2l+1, 2l+1)`
Examples:
>>> m = Irreps("1o").wigD_from_angles(0, 0 ,0, 1)
>>> print(m)
[[-1, 0, 0],
[ 0, -1, 0],
[ 0, 0, -1]]
"""
return _direct_sum(*[ir.wigD_from_angles(alpha, beta, gamma, k) for mul, ir in self for _ in range(mul)])
[docs] def wigD_from_matrix(self, R):
r"""
Representation wigner D matrices of O(3) from rotation matrices.
Args:
R (Tensor): Rotation matrices. The shape of Tensor is :math:`(..., 3, 3)`.
Returns:
Tensor, representation wigner D matrix of O(3). The shape of Tensor is :math:`(..., 2l+1, 2l+1)`
Raises:
TypeError: If `R` is not a Tensor.
Examples:
>>> m = Irreps("1o").wigD_from_matrix(-ops.eye(3))
>>> print(m)
[[-1, 0, 0],
[ 0, -1, 0],
[ 0, 0, -1]]
"""
if not isinstance(R, Tensor):
raise TypeError
d = Tensor(np.sign(np.linalg.det(R.asnumpy())))
R = _expand_last_dims(d) * R
k = (1 - d) / 2
return self.wigD_from_angles(*matrix_to_angles(R), k)