# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import itertools
import collections
import dataclasses
import numpy as np
from mindspore import jit_class, Tensor, ops
from .wigner import wigner_D
from .rotation import matrix_to_angles
from ..utils.func import broadcast_args, _to_tensor, norm_keep, _expand_last_dims, narrow
from ..utils.perm import _inverse
from ..utils.linalg import _direct_sum
# pylint: disable=C0111
@jit_class
@dataclasses.dataclass(init=False, frozen=True)
class Irrep:
    r"""
    Irreducible representation of O(3). This class does not contain any data; it is a structure that describes the representation.
    It is typically used as an argument of other classes of the library to define the input and output representations of functions.

    Args:
        l (Union[int, str]): non-negative integer, the degree of the representation, :math:`l = 0, 1, \dots`.
            When `p` is ``None``, `l` may instead be a string such as ``"1o"``, ``"2e"`` or ``"3y"`` giving degree and
            parity (``e``: :math:`p = 1`, ``o``: :math:`p = -1`, ``y``: :math:`p = (-1)^l`), an existing `Irrep`,
            a multiplicity/irrep pair, or an ``(l, p)`` tuple.
        p (int): {1, -1}, the parity of the representation. Default: ``None``.

    Raises:
        ValueError: If `l` is negative or `p` is not in {1, -1}.
        ValueError: If `l` cannot be converted to an `Irrep`.
        TypeError: If the degree `l` is not an int.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> from mindchemistry.e3.o3 import Irrep
        >>> Irrep(0, 1)
        0e
        >>> Irrep("1y")
        1o
        >>> Irrep("2o").dim
        5
        >>> Irrep("2e") in Irrep("1o") * Irrep("1o")
        True
        >>> Irrep("1o") + Irrep("2o")
        1x1o+1x2o
    """
    l: int
    p: int

    def __init__(self, l, p=None):
        if p is None:
            # Without an explicit parity, `l` may encode both degree and parity.
            if isinstance(l, Irrep):
                l, p = l.l, l.p
            elif isinstance(l, _MulIr):
                l, p = l.ir.l, l.ir.p
            elif isinstance(l, str):
                name = l.strip()
                try:
                    degree = int(name[:-1])
                    if degree < 0:
                        raise ValueError(f"negative degree {degree}")
                    parity = {
                        'e': 1,
                        'o': -1,
                        'y': (-1) ** degree,
                    }[name[-1]]
                except (ValueError, KeyError, IndexError) as err:
                    raise ValueError(f"Cannot convert {l!r} to an Irrep.") from err
                l, p = degree, parity
            elif isinstance(l, tuple):
                l, p = l
        if not isinstance(l, int):
            raise TypeError(f"The degree of an Irrep must be an int, but got {type(l)}.")
        if l < 0:
            raise ValueError(f"The degree of an Irrep must be non-negative, but got {l}.")
        if p not in [-1, 1]:
            raise ValueError(f"The parity of an Irrep must be 1 or -1, but got {p}.")
        # The dataclass is frozen, so bypass its __setattr__ guard.
        object.__setattr__(self, "l", l)
        object.__setattr__(self, "p", p)

    def __repr__(self):
        """Representation of the Irrep, e.g. ``1o``."""
        p = {+1: 'e', -1: 'o'}[self.p]
        return f"{self.l}{p}"

    @classmethod
    def iterator(cls, lmax=None):
        r"""
        Iterate over irreps in the order ``0e, 0o, 1o, 1e, 2e, 2o, ...``.

        For every degree the parity :math:`(-1)^l` is yielded first, then the opposite one.

        Args:
            lmax (int): largest degree to yield (inclusive). ``None`` iterates forever. Default: ``None``.

        Returns:
            generator of `Irrep`. It is empty when `lmax` is negative.
        """
        for l in itertools.count():
            # Checking before yielding also terminates for a negative lmax,
            # which previously looped forever.
            if lmax is not None and l > lmax:
                return
            yield Irrep(l, (-1) ** l)
            yield Irrep(l, -(-1) ** l)

    def wigD_from_angles(self, alpha, beta, gamma, k=None):
        r"""
        Representation wigner D matrices of O(3) from Euler angles.

        Args:
            alpha (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\alpha` around Y axis, applied third.
            beta (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\beta` around X axis, applied second.
            gamma (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\gamma` around Y axis, applied first.
            k (Union[None, Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): How many times the parity is applied.
                ``None`` means zero times. Default: ``None`` .

        Returns:
            Tensor, representation wigner D matrix of O(3). The shape of Tensor is :math:`(..., 2l+1, 2l+1)` .

        Examples:
            >>> m = Irrep(1, -1).wigD_from_angles(0, 0 ,0, 1)
            >>> print(m)
            [[-1,  0,  0],
            [ 0, -1,  0],
            [ 0,  0, -1]]
        """
        if k is None:
            k = ops.zeros_like(_to_tensor(alpha))
        alpha, beta, gamma, k = broadcast_args(alpha, beta, gamma, k)
        # The rotation part comes from wigner_D; the parity contributes a factor p**k.
        return wigner_D(self.l, alpha, beta, gamma) * self.p ** _expand_last_dims(k)

    def wigD_from_matrix(self, R):
        r"""
        Representation wigner D matrices of O(3) from rotation matrices.

        Args:
            R (Tensor): Rotation matrices. The shape of Tensor is :math:`(..., 3, 3)`.

        Returns:
            Tensor, representation wigner D matrix of O(3). The shape of Tensor is :math:`(..., 2l+1, 2l+1)`.

        Raises:
            TypeError: If `R` is not a Tensor.

        Examples:
            >>> from mindspore import ops
            >>> m = Irrep(1, -1).wigD_from_matrix(-ops.eye(3))
            >>> print(m)
            [[-1,  0,  0],
            [ 0, -1,  0],
            [ 0,  0, -1]]
        """
        if not isinstance(R, Tensor):
            raise TypeError(f"R should be a Tensor, but got {type(R)}.")
        # Sign of the determinant: +1 for proper rotations, -1 for improper ones.
        d = Tensor(np.sign(np.linalg.det(R.asnumpy())))
        # Multiplying by d turns an improper matrix into a proper rotation...
        R = _expand_last_dims(d) * R
        # ...and the removed inversion is re-applied as parity: k = 1 iff d = -1.
        k = (1. - d) / 2
        return self.wigD_from_angles(*matrix_to_angles(R), k)

    @property
    def dim(self) -> int:
        """Dimension of the representation, :math:`2l + 1`."""
        return 2 * self.l + 1

    def is_scalar(self) -> bool:
        """Whether this is the trivial representation ``0e``."""
        return self.l == 0 and self.p == 1

    def __mul__(self, other):
        r"""
        Generate the irreps from the product of two irreps.

        The degrees run from :math:`|l_1 - l_2|` to :math:`l_1 + l_2` and the parity is :math:`p_1 p_2`.
        `other` is validated immediately, not lazily on first iteration.

        Returns:
            generator of `Irrep`.
        """
        other = Irrep(other)
        parity = self.p * other.p
        lmin = abs(self.l - other.l)
        lmax = self.l + other.l
        return (Irrep(l, parity) for l in range(lmin, lmax + 1))

    def __rmul__(self, other):
        r"""
        Return `Irreps` of multiple `Irrep`.

        Args:
            other (int): multiple number of the `Irrep`.

        Returns:
            `Irreps` - corresponding multiple `Irrep`.

        Raises:
            TypeError: If `other` is not int.
        """
        if not isinstance(other, int):
            raise TypeError(f"The multiplicity should be an int, but got {type(other)}.")
        return Irreps([(other, self)])

    def __add__(self, other):
        r"""Sum of two irreps."""
        return Irreps(self) + Irreps(other)

    def __radd__(self, other):
        r"""Sum of two irreps."""
        return Irreps(other) + Irreps(self)

    def __iter__(self):
        r"""Deconstruct the irrep into ``l`` and ``p``."""
        yield self.l
        yield self.p

    def __lt__(self, other):
        r"""Compare the order of two irreps, by degree first and then parity (odd before even)."""
        return (self.l, self.p) < (other.l, other.p)

    def __eq__(self, other):
        """Compare two irreps; `other` may be anything convertible to an `Irrep` (e.g. ``"1o"``)."""
        try:
            other = Irrep(other)
        except (TypeError, ValueError):
            # Not an irrep: let Python fall back instead of raising.
            return NotImplemented
        return (self.l, self.p) == (other.l, other.p)
@jit_class
@dataclasses.dataclass(init=False, frozen=True)
class _MulIr:
    """An `Irrep` together with its multiplicity, printed as ``{mul}x{ir}`` (e.g. ``3x1o``)."""
    mul: int
    ir: Irrep

    def __init__(self, mul, ir=None):
        if ir is None:
            # Allow construction from a single ``(mul, ir)`` pair.
            mul, ir = mul
        if not (isinstance(mul, int) and isinstance(ir, Irrep)):
            raise TypeError(
                f"_MulIr expects an int multiplicity and an Irrep, but got {type(mul)} and {type(ir)}.")
        # The dataclass is frozen, so bypass its __setattr__ guard.
        object.__setattr__(self, "mul", mul)
        object.__setattr__(self, "ir", ir)

    @property
    def dim(self):
        """Total dimension: multiplicity times the dimension of the irrep."""
        return self.mul * self.ir.dim

    def __repr__(self):
        """Representation of the irrep."""
        return f"{self.mul}x{self.ir}"

    def __iter__(self):
        """Deconstruct the mulirrep into `mul` and `ir`."""
        yield self.mul
        yield self.ir

    def __lt__(self, other):
        """Compare the order of two mulirreps: by irrep first, then by multiplicity."""
        try:
            return (self.ir, self.mul) < (other.ir, other.mul)
        except AttributeError:
            return NotImplemented

    def __eq__(self, other):
        """Compare two mulirreps; objects without ``mul``/``ir`` are not equal."""
        try:
            return (self.mul, self.ir) == (other.mul, other.ir)
        except AttributeError:
            return NotImplemented
@jit_class
@dataclasses.dataclass(init=False, frozen=False)
class Irreps:
    r"""
    Direct sum of irreducible representations of O(3). This class does not contain any data; it is a structure that describes the representation.
    It is typically used as an argument of other classes of the library to define the input and output representations of functions.

    Args:
        irreps (Union[str, Irrep, Irreps, List[Tuple[int]]]): a string to represent the direct sum of irreducible representations,
            or an `Irrep`, an `Irreps`, or an iterable whose items are strings, `Irrep`, ints (degree with even parity)
            or ``(mul, ir)`` pairs. ``None`` gives the empty representation.

    Raises:
        ValueError: If `irreps` cannot be converted to an `Irreps`.
        ValueError: If a multiplicity is negative or not an int.
        TypeError: If `irreps` is of an unsupported, non-iterable type.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> from mindchemistry.e3.o3 import Irreps
        >>> x = Irreps([(100, (0, 1)), (50, (1, 1))])
        >>> x
        100x0e+50x1e
        >>> x.dim
        250
        >>> Irreps("100x0e+50x1e+0x2e")
        100x0e+50x1e+0x2e
        >>> Irreps("100x0e+50x1e+0x2e").lmax
        1
        >>> Irrep("2e") in Irreps("0e+2e")
        True
        >>> Irreps(), Irreps("")
        (, )
        >>> Irreps('2x1o+1x0o') * Irreps('2x1o+1x0e')
        1x0o+4x0e+2x1o+4x1e+2x1e+4x2e
    """
    __slots__ = ('data', 'dim', 'slice', 'slice_tuples')

    def __init__(self, irreps=None):
        if isinstance(irreps, Irreps):
            self.data = irreps.data
            self.dim = irreps.dim
            # Copy the lists so the two objects do not share mutable state.
            self.slice = list(irreps.slice)
            self.slice_tuples = list(irreps.slice_tuples)
        else:
            if irreps is None:
                out = ()
            elif isinstance(irreps, Irrep):
                out = (_MulIr(1, Irrep(irreps)),)
            elif isinstance(irreps, _MulIr):
                out = (irreps,)
            elif isinstance(irreps, str):
                out = self._parse_string(irreps)
            else:
                out = self.handle_irreps(irreps, ())
            self.data = out
            self.dim = self._dim()
            self.slice = self._slices()
            self.slice_tuples = [(s.start, s.stop - s.start) for s in self.slice]

    @staticmethod
    def _parse_mul_ir(text):
        """Parse one term such as ``"3x1o"`` or ``"2e"`` into ``(mul, Irrep)``; a missing multiplicity means 1."""
        if 'x' in text:
            mul, ir = text.split('x')
            return int(mul), Irrep(ir)
        return 1, Irrep(text)

    def _parse_string(self, text):
        """Parse a ``"+"``-separated string such as ``"2x0e + 1o"`` into a tuple of `_MulIr`."""
        out = ()
        if text.strip() == "":
            return out
        try:
            for term in text.split('+'):
                mul, ir = self._parse_mul_ir(term)
                if mul < 0:
                    raise ValueError(f"negative multiplicity {mul}")
                out += (_MulIr(mul, ir),)
        except (TypeError, ValueError) as err:
            raise ValueError(f"Cannot convert {text!r} to Irreps.") from err
        return out

    def handle_irreps(self, irreps, out):
        """
        Append one `_MulIr` to `out` for each item of the iterable `irreps` and return the new tuple.

        Raises:
            ValueError: If an item cannot be interpreted, or its multiplicity is negative or not an int.
        """
        for mir in irreps:
            if isinstance(mir, str):
                mul, ir = self._parse_mul_ir(mir)
            elif isinstance(mir, Irrep):
                mul, ir = 1, mir
            elif isinstance(mir, _MulIr):
                mul, ir = mir
            elif isinstance(mir, int):
                # A bare int is read as a degree with even parity.
                mul, ir = 1, Irrep(l=mir, p=1)
            elif len(mir) == 2:
                mul, ir = mir
                ir = Irrep(ir)
            else:
                # Without this branch, stale values from the previous item were reused.
                raise ValueError(f"Cannot convert {mir!r} to a multiplicity/irrep pair.")
            if not (isinstance(mul, int) and mul >= 0):
                raise ValueError(f"The multiplicity should be a non-negative int, but got {mul!r}.")
            out += (_MulIr(mul, ir),)
        return out

    def __iter__(self):
        return iter(self.data)

    def __hash__(self):
        return hash(self.data)

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        """Representation of the irreps."""
        return "+".join(f"{mir}" for mir in self.data)

    def __eq__(self, other):
        """Compare two irreps term by term (order and multiplicities matter); `other` may be any convertible value."""
        if other is None:
            # Irreps(None) is empty, but an empty Irreps should not compare equal to None.
            return NotImplemented
        try:
            other = Irreps(other)
        except (TypeError, ValueError):
            return NotImplemented
        return self.data == other.data

    def __contains__(self, ir):
        """
        Check if an irrep or an irreps is in the representation.

        For an `Irreps`, every term must be matched to a distinct term of `self` with the same irrep and at least
        the same multiplicity.
        """
        try:
            ir = Irrep(ir)
        except (TypeError, ValueError):
            wanted = Irreps(ir)
        else:
            return ir in (irrep for _, irrep in self.data)

        used = [False] * len(self.data)

        def _match(i):
            # Backtracking search: assign wanted.data[i:] to distinct, still unused terms of self.
            if i == len(wanted.data):
                return True
            need = wanted.data[i]
            for j, have in enumerate(self.data):
                if used[j] or need.mul > have.mul or need.ir != have.ir:
                    continue
                used[j] = True
                if _match(i + 1):
                    return True
                used[j] = False
            return False

        return _match(0)

    def __add__(self, irreps):
        """Concatenate two representations (no simplification)."""
        irreps = Irreps(irreps)
        return Irreps(self.data + irreps.data)

    def __radd__(self, irreps):
        """Concatenate with a left operand that is convertible to `Irreps` (e.g. a string)."""
        return Irreps(irreps) + self

    def __mul__(self, other):
        r"""
        Multiply by an int or by another `Irreps`.

        Args:
            other (Union[int, Irreps]): with an int, every multiplicity is multiplied by `other`. With an `Irreps`,
                the result contains every irrep in the product of each pair of terms, with multiplicity
                :math:`mul_1 \cdot mul_2`. It is then simplified (adjacent duplicates merged, zero multiplicities
                dropped) and sorted, so non-adjacent duplicates may remain separate after sorting.

        Returns:
            `Irreps`.
        """
        if isinstance(other, Irreps):
            terms = []
            for mir_1 in self.data:
                for mir_2 in other.data:
                    for ir in mir_1.ir * mir_2.ir:
                        terms.append((mir_1.mul * mir_2.mul, ir))
            # Building once avoids re-creating an Irreps for every term.
            return Irreps(terms).simplify().sort().irreps
        return Irreps([(mul * other, ir) for mul, ir in self.data])

    def __rmul__(self, other):
        r"""
        Return `Irreps` with every multiplicity multiplied by `other`.

        Args:
            other (int): multiple number of the `Irreps`.

        Returns:
            `Irreps` - repeated multiple `Irreps`.
        """
        return self * other

    def _dim(self):
        r"""The total dimension of the representation, :math:`\sum mul \cdot (2l + 1)`."""
        return sum(mul * ir.dim for mul, ir in self.data)

    @property
    def num_irreps(self):
        """Total number of irreps, i.e. the sum of all multiplicities."""
        return sum(mul for mul, _ in self.data)

    @property
    def ls(self):
        """Degree of every irrep, each repeated `mul` times."""
        res = []
        for mul, (l, _) in self.data:
            res.extend([l] * mul)
        return res

    @property
    def lmax(self):
        """Largest degree among irreps with non-zero multiplicity."""
        if len(self) == 0:
            raise ValueError("Cannot get lmax of empty Irreps")
        return max(self.ls)

    def count(self, ir):
        r"""
        Multiplicity of `ir`.

        Args:
            ir (Irrep): `Irrep`

        Returns:
            int, total multiplicity of `ir`.

        Examples:
            >>> Irreps("1o + 3x2e").count("2e")
            3
        """
        ir = Irrep(ir)
        return sum(mul for mul, irrep in self.data if ir == irrep)

    def simplify(self):
        """
        Simplify the representations: merge adjacent equal irreps and drop zero multiplicities.

        Non-adjacent duplicates are kept separate; use `sort` first to merge all of them.

        Returns:
            `Irreps`

        Examples:
            >>> Irreps("1e + 1e + 0e").simplify()
            2x1e+1x0e
            >>> Irreps("1e + 1e + 0e + 1e").simplify()
            2x1e+1x0e+1x1e
        """
        out = []
        for mul, ir in self.data:
            if out and out[-1][1] == ir:
                out[-1] = (out[-1][0] + mul, ir)
            elif mul > 0:
                out.append((mul, ir))
        return Irreps(out)

    def remove_zero_multiplicities(self):
        """
        Remove any irreps with multiplicities of zero.

        Returns:
            `Irreps`

        Examples:
            >>> Irreps("4x0e + 0x1o + 2x3e").remove_zero_multiplicities()
            4x0e+2x3e
        """
        return Irreps([(mul, ir) for mul, ir in self.data if mul > 0])

    def _slices(self):
        r"""
        List of slices corresponding to indices for each irrep.

        Examples:
            >>> Irreps('2x0e + 1e')._slices()
            [slice(0, 2, None), slice(2, 5, None)]
        """
        s = []
        i = 0
        for mir in self.data:
            s.append(slice(i, i + mir.dim))
            i += mir.dim
        return s

    def sort(self):
        r"""
        Sort the representations by increasing degree (odd before even for equal degree); ties keep their original order.

        Returns:
            irreps (`Irreps`) - sorted `Irreps`
            p (tuple[int]) - permute orders. `p[old_index] = new_index`
            inv (tuple[int]) - inversed permute orders. `inv[new_index] = old_index`

        Examples:
            >>> Irreps("1e + 0e + 1e").sort().irreps
            1x0e+1x1e+1x1e
            >>> Irreps("2o + 1e + 0e + 1e").sort().p
            (3, 1, 0, 2)
            >>> Irreps("2o + 1e + 0e + 1e").sort().inv
            (2, 1, 3, 0)
        """
        Ret = collections.namedtuple("sort", ["irreps", "p", "inv"])
        # The original index breaks ties, so the multiplicity is never compared.
        out = sorted((ir, i, mul) for i, (mul, ir) in enumerate(self.data))
        inv = tuple(i for _, i, _ in out)
        p = _inverse(inv)
        irreps = Irreps([(mul, ir) for ir, _, mul in out])
        return Ret(irreps, p, inv)

    def filter(self, keep=None, drop=None):
        r"""
        Filter the `Irreps` by either `keep` or `drop`.

        Args:
            keep (Union[str, Irrep, Irreps, List[str, Irrep]]): list of irrep to keep. Default: None.
            drop (Union[str, Irrep, Irreps, List[str, Irrep]]): list of irrep to drop. Default: None.

        Returns:
            `Irreps`, filtered irreps. `self` is returned unchanged when both arguments are `None`.

        Raises:
            ValueError: If both `keep` and `drop` are not `None`.

        Examples:
            >>> Irreps("1o + 2e").filter(keep="1o")
            1x1o
            >>> Irreps("1o + 2e").filter(drop="1o")
            1x2e
        """
        if keep is None and drop is None:
            return self
        if keep is not None and drop is not None:
            raise ValueError("Cannot specify both keep and drop")
        if keep is not None:
            kept = {mir.ir for mir in Irreps(keep).data}
            return Irreps([(mul, ir) for mul, ir in self.data if ir in kept])
        dropped = {mir.ir for mir in Irreps(drop).data}
        return Irreps([(mul, ir) for mul, ir in self.data if ir not in dropped])

    def decompose(self, v, batch=False):
        r"""
        Decompose a vector by `Irreps`.

        Args:
            v (Tensor): the vector to be decomposed; its last axis must have length `self.dim`.
            batch (bool): whether reshape the result such that there is at least a batch dimension. Default: `False`.

        Returns:
            List of Tensors, one per term, each of shape ``batch_shape + (mul, 2l+1)``.

        Raises:
            TypeError: If v is not Tensor.
            ValueError: If length of the vector `v` is not matching with dimension of `Irreps`.

        Examples:
            >>> import mindspore as ms
            >>> input = ms.Tensor([1, 2, 3])
            >>> m = Irreps("1o").decompose(input)
            >>> print(m)
            [Tensor(shape=[1,3], dtype=Int64, value=
            [[1,2,3]])]
        """
        if not isinstance(v, Tensor):
            raise TypeError(
                f"The input for decompose should be Tensor, but got {type(v)}.")
        if self.dim != v.shape[-1]:
            raise ValueError(
                f"the shape of input {v.shape[-1]} do not match irreps dimension {self.dim}.")
        batch_shape = v.shape[:-1]
        # A 1-D input has no batch axis; with batch=True a leading axis of size 1 is added.
        lead_shape = (1,) + batch_shape if v.ndim == 1 and batch else batch_shape
        return [
            narrow(v, -1, start, length).reshape(lead_shape + (mir.mul, mir.ir.dim))
            for (start, length), mir in zip(self.slice_tuples, self.data)
        ]

    @staticmethod
    def spherical_harmonics(lmax, p=-1):
        r"""
        Representation of the spherical harmonics.

        Args:
            lmax (int): maximum of `l`.
            p (int): {1, -1}, the parity of the representation; degree `l` gets parity :math:`p^l`.

        Returns:
            `Irreps`, representation of :math:`(Y^0, Y^1, \dots, Y^{\mathrm{lmax}})`.

        Examples:
            >>> Irreps.spherical_harmonics(3)
            1x0e+1x1o+1x2e+1x3o
            >>> Irreps.spherical_harmonics(4, p=1)
            1x0e+1x1e+1x2e+1x3e+1x4e
        """
        return Irreps([(1, (l, p ** l)) for l in range(lmax + 1)])

    def randn(self, *size, normalization='component'):
        r"""
        Random tensor.

        Args:
            *size (List[int]): size of the output tensor, needs to contains a `-1`.
            normalization (str): {'component', 'norm'}, type of normalization method. With ``'norm'`` every irrep
                block is scaled to unit norm.

        Returns:
            Tensor, the shape is `size` where `-1` is replaced by `self.dim`.

        Raises:
            ValueError: If `size` has no `-1` or `normalization` is not supported.

        Examples:
            >>> Irreps("5x0e + 10x1o").randn(5, -1, 5, normalization='norm').shape
            (5, 35, 5)
        """
        if -1 not in size:
            raise ValueError("size needs to contain a -1 marking the irreps dimension")
        di = size.index(-1)
        lsize = size[:di]
        rsize = size[di + 1:]
        if normalization == 'component':
            return ops.standard_normal((*lsize, self.dim, *rsize))
        if normalization == 'norm':
            x_list = []
            for mul, ir in self.data:
                if mul < 1:
                    continue
                r = ops.standard_normal((*lsize, mul, ir.dim, *rsize))
                # Axis di + 1 is the (2l+1) axis of each irrep copy.
                r = r / norm_keep(r, axis=di + 1)
                x_list.append(r.reshape((*lsize, -1, *rsize)))
            return ops.concat(x_list, axis=di)
        raise ValueError("Normalization needs to be 'norm' or 'component'")

    def wigD_from_angles(self, alpha, beta, gamma, k=None):
        r"""
        Representation wigner D matrices of O(3) from Euler angles.

        Args:
            alpha (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\alpha` around Y axis, applied third.
            beta (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\beta` around X axis, applied second.
            gamma (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\gamma` around Y axis, applied first.
            k (Union[None, Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): How many times the parity is applied. Default: None.

        Returns:
            Tensor, block-diagonal representation wigner D matrix of O(3). The shape of Tensor is :math:`(..., dim, dim)`
            where `dim` is `self.dim`.

        Examples:
            >>> m = Irreps("1o").wigD_from_angles(0, 0 ,0, 1)
            >>> print(m)
            [[-1,  0,  0],
            [ 0, -1,  0],
            [ 0,  0, -1]]
        """
        blocks = []
        for mul, ir in self.data:
            if mul == 0:
                continue
            # All copies of the same irrep share one matrix; compute it only once.
            d_matrix = ir.wigD_from_angles(alpha, beta, gamma, k)
            blocks.extend([d_matrix] * mul)
        return _direct_sum(*blocks)

    def wigD_from_matrix(self, R):
        r"""
        Representation wigner D matrices of O(3) from rotation matrices.

        Args:
            R (Tensor): Rotation matrices. The shape of Tensor is :math:`(..., 3, 3)`.

        Returns:
            Tensor, block-diagonal representation wigner D matrix of O(3). The shape of Tensor is :math:`(..., dim, dim)`
            where `dim` is `self.dim`.

        Raises:
            TypeError: If `R` is not a Tensor.

        Examples:
            >>> m = Irreps("1o").wigD_from_matrix(-ops.eye(3))
            >>> print(m)
            [[-1,  0,  0],
            [ 0, -1,  0],
            [ 0,  0, -1]]
        """
        if not isinstance(R, Tensor):
            raise TypeError(f"R should be a Tensor, but got {type(R)}.")
        # Sign of the determinant: +1 for proper rotations, -1 for improper ones.
        d = Tensor(np.sign(np.linalg.det(R.asnumpy())))
        # Turn R into a proper rotation and re-apply the inversion as parity (k = 1 iff d = -1).
        R = _expand_last_dims(d) * R
        k = (1 - d) / 2
        return self.wigD_from_angles(*matrix_to_angles(R), k)