# Copyright 2023-2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"einsum main file"
import math
from collections import defaultdict
from mindspore import mint, nn, mutable
from mindspore import ops as P
from mindspore.common.tensor import Tensor
from mindspore.ops._primitive_cache import _get_cache_prim
from . import constants as C
from .label_order import LabelOrder
from .opt_einusm_path import parse_opt_trace
from .sumproduct_pair import (sumproduct_pair_info, out_cacl_info, rearrange_tensor_to_mul,
rearrange_tensor_to_bmm, rearrange_tensor_to_out, prod_lst)
def _parse_equation(equation: str):
"""
Parse the einsum equation into left-hand side (LHS), right-hand side (RHS), and number of operands.
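    Example (illustrative): _parse_equation("ij,jk->ik") returns ("ij,jk", "ik", 2).
    Any '...' in the equation has already been collapsed to a single '.' placeholder
    by the time lhs and rhs are returned.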
"""
arrow_pos = equation.find("->")
if arrow_pos == -1:
raise ValueError(f"invalid equation {equation}: require '->'")
equation = equation.replace('...', '.')
arrow_pos = equation.find("->")
lhs = equation[:arrow_pos]
rhs = equation[arrow_pos + 2:]
num_ops = lhs.count(",") + 1
return lhs, rhs, num_ops
def _parse_ellipsis(lhs: str, rhs: str):
"""
Parse the ellipsis dims of equation
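    Example (illustrative): after '...' has been collapsed to '.', a call like
    _parse_ellipsis("i.,.j", ".") returns [(1, 0), (0, 1), (0, 0)], i.e. the number of
    labels before and after the ellipsis for each operand and for the output.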
"""
op_labels = lhs.split(",") + [rhs]
ellipsis_idxes = []
has_ellipsis = False
for s in op_labels:
ecnt = s.count(".")
if ecnt > 1:
raise ValueError(f"invalid equation {lhs} with multiple '...'")
if ecnt == 1:
pre, post = s.split(".")
ellipsis_idxes.append((len(pre), len(post)))
has_ellipsis = True
else:
ellipsis_idxes.append(None)
if not has_ellipsis:
return None
return ellipsis_idxes
def _sum_dims_helper(a_shape: list, a_sums: tuple[str, ...]):
"""
Helper function to filter out dimensions to be summed and return
the remaining dimensions and their indices.
    a_shape: list[tuple[str, int], ...]; (label, size) pairs like [('i', 2), ('j', 3)]
    a_sums: tuple[str, ...]
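    Example (illustrative): _sum_dims_helper([('i', 2), ('j', 3), ('k', 4)], ('j',))
    returns ([('i', 2), ('k', 4)], (1,)).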
"""
res = []
sum_dims = []
for i, (k, v) in enumerate(a_shape):
if k not in a_sums:
res.append((k, v))
else:
sum_dims.append(i)
return res, tuple(sum_dims)
def _cacl_mul_reshape(tensor: Tensor, add_dim_info: tuple[int, tuple[int, ...]]):
"""
Calculate the new shape and permutation indices for multiplication operations.
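    Example (illustrative): a tensor of shape (2, 3) with add_dim_info = (1, (0, 2, 1))
    is padded to (2, 3, 1) and then reshaped to (2, 1, 3).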
"""
if add_dim_info[0] == 0:
return tensor
add_dims, perm_ids = add_dim_info
    added_shape = tensor.shape + (1,) * add_dims
new_shape = tuple(added_shape[i] for i in perm_ids)
return tensor.reshape(new_shape)
def _reshape_of_bmm(ta: Tensor, gb: tuple, m: int, k: int, is_trans: bool):
"""
reshape tensor for bmm with BMK or BKM format
"""
new_shape = gb + (k, m) if is_trans else gb + (m, k)
if new_shape != ta.shape:
return ta.reshape(new_shape)
return ta
def _cacl_matmul_reshape(ta, tb, bmm_info):
"""Reshape the tensor for matrix multiplication operations.
Types:
ta: Tensor
tb: Tensor
bmm_info: tuple[bool, bool, bool, tuple[int, ...], tuple[int, ...],
tuple[int, ...], tuple[int, ...]]
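    Example (illustrative): with ta of shape (2, 3, 4, 5), tb of shape (2, 20, 6) and
    bmm_info = (True, False, False, (0,), (1,), (2,), (2, 3)), ta is reshaped to
    (2, 3, 20), tb is left as (2, 20, 6), and out_shape is None (no final reshape needed).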
"""
a_shape, b_shape = ta.shape, tb.shape
is_batch, transpose_a, transpose_b, a_b, a_m, b_n, a_k = bmm_info
m_dims = tuple(a_shape[d] for d in a_m)
m = prod_lst(m_dims)
n_dims = tuple(b_shape[d] for d in b_n)
n = prod_lst(n_dims)
k = prod_lst(tuple(a_shape[d] for d in a_k))
gb, b_dims = (), ()
if is_batch:
b_dims = tuple(a_shape[d] for d in a_b)
b = prod_lst(b_dims)
gb = (b,)
out_shape = b_dims + m_dims + n_dims
if out_shape == gb + (m, n):
out_shape = None
    # the transpose flag and the operand side (left or right) of the bmm decide
    # whether the operand is laid out as (B, M, K) or (B, K, M)
ta = _reshape_of_bmm(ta, gb, m, k, transpose_a)
tb = _reshape_of_bmm(tb, gb, n, k, not transpose_b)
return ta, tb, out_shape
def _remove_a_diagonal(labels: str, shape: tuple[int, ...]):
"""
    Fold one pair of duplicated labels into a diagonal: both occurrences are removed
    from labels and shape, and a single diagonal dimension is appended at the end.
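    Example (illustrative): _remove_a_diagonal("iij", (2, 2, 3)) returns
    ((0, 1), "ji", (3, 2)); the duplicated 'i' dims (0 and 1) become one trailing diagonal dim.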
"""
if len(labels) != len(shape):
raise ValueError(f"labels: {labels} and tensor shape: {shape} are different size")
for i in range(len(labels) - 1, 0, -1):
c = labels[i]
idx = labels.find(c, 0, i)
if idx >= 0:
if shape[i] != shape[idx]:
raise ValueError(f"tensor diagonal requires same size, \
while with {shape[i]} and {shape[idx]}")
pairs = [(labels[j], shape[j]) for j in range(len(labels)) if j not in (i, idx)]
new_labels = [a for a, _ in pairs] + [c]
new_shape = tuple(b for _, b in pairs) + (shape[i],)
return (idx, i), "".join(new_labels), new_shape
return None, labels, shape
def _flat_empty_struct(st: list):
"""
Flattens an empty structure to None if it contains no non-empty elements.
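    Example (illustrative): _flat_empty_struct([(), None]) returns None, while
    _flat_empty_struct([((0, 1),), None]) returns (((0, 1),), None).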
"""
for e in st:
if e:
return tuple(st)
return None
def _convert_1_to_2(s: int):
if s == 1:
return 2
return s
def _replace_e1_shape(shapes):
"""Shape equal to 1 will affect preprocessing, use 2 instead.
Replaces all shape elements equal to 1 with 2.
Args:
shapes: list[tuple[int, ...], ...]
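    Example (illustrative): _replace_e1_shape([(1, 3), (4, 1)]) returns ((2, 3), (4, 2)).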
"""
res = []
for shape in shapes:
new_shape = tuple(_convert_1_to_2(s) for s in shape)
res.append(new_shape)
return tuple(res)
def _get_ellipsis_shape(shape, label_part: tuple[int, int], elli_shapes: tuple[int, ...]):
"""
replace shape of ellipsis dims
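    Example (illustrative): _get_ellipsis_shape((2, 3, 4, 5), (1, 1), None) returns
    ((2, 12, 5), (3, 4)): the two ellipsis dims are flattened into a single dim of size 12.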
"""
pre_ellipsis, post_ellipsis = label_part
num_dims = len(shape)
total_labels = pre_ellipsis + post_ellipsis
if num_dims < total_labels:
raise ValueError(f"({shape}) is invalid for given equtation, require not less than {total_labels}.")
# The shape of the dimension before '...'
pre_ellipsis_shape = shape[: pre_ellipsis]
# The shape of the dimension after '...'
post_ellipsis_shape = shape[num_dims - post_ellipsis :]
if elli_shapes is not None:
# note: elli_shapes may be tuple([])
new_shape = pre_ellipsis_shape + elli_shapes + post_ellipsis_shape
else:
elli_shapes = tuple(shape[pre_ellipsis: num_dims - post_ellipsis])
new_shape = pre_ellipsis_shape + tuple([prod_lst(elli_shapes)]) + post_ellipsis_shape
return new_shape, elli_shapes
def _update_weight(cacl_info):
"""
    Update each step's weight from the tensor's data volume
    (log-scaled with base 64, relative to C.MIN_WEIGHT_PROD).
"""
for info in cacl_info:
w = max(info["WEIGHT"], C.MIN_WEIGHT_PROD)
info["WEIGHT"] = math.log(w / C.MIN_WEIGHT_PROD, 64) + 1.0
class Einsum(nn.Cell):
"""Einsum operation using Einstein summation convention.
This operator performs tensor computations using Einstein summation convention (Einsum).
Supports diagonalization, reduction, transposition, matrix multiplication, product operations,
inner products, etc.
Args:
equation (str): Specifies the computation to be performed. Only accepts:
- Letters ([a-z][A-Z]): Represent dimensions of input tensors
- ...: anonymous dimensions
- Commas (','): Separate tensor dimensions
- Arrow ('->'): Left side specifies input tensors, right side specifies desired output dimensions
        use_opt (bool, optional): Whether to optimize the contraction path. Defaults to ``True``.
            When set to ``False``, no contraction path optimization is performed.
Inputs:
- **operands** (List[Tensor]): Variable number of tensor inputs.
Outputs:
- **out_tensor** (Tensor): The result of the einsum operation.
Examples:
>>> import mindspore as ms
>>> from mindspore import nn, Tensor, ops
>>> import numpy as np
        >>> # Import Einsum from the package that provides this module
>>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), ms.float32)
>>> y = Tensor(np.array([[2.0, 3.0], [1.0, 2.0], [4.0, 5.0]]), ms.float32)
>>> equation = "ij,jk->ik"
>>> einsum = Einsum(equation, use_opt=False)
>>> output = einsum(x, y)
>>> print(output.shape)
(2, 2)
>>> shapes = [(156, 16, 16), (660, 128, 16), (660, 128, 16)]
>>> x, y, z = [ops.randn(tp) for tp in shapes]
>>> equation = "ijk,zui,zuj->zuk"
>>> einsum = Einsum(equation, use_opt=True)
>>> output = einsum(x, y, z)
"""
def __init__(self, equation, use_opt=True):
"""Initializes the Einsum operator.
This operator performs tensor computations using Einstein summation convention (Einsum).
Supports diagonalization, reduction, transposition, matrix multiplication, product operations,
inner products, etc.
Args:
equation (str): Specifies the computation to be performed. Only accepts:
- Letters ([a-z][A-Z]): Represent dimensions of input tensors
- ...: anonymous dimensions
- Commas (','): Separate tensor dimensions
- Arrow ('->'): Left side specifies input tensors, right side specifies desired output dimensions
            use_opt (bool, optional): Whether to optimize the contraction path. Defaults to `True`.
                When set to `False`, no contraction path optimization is performed.
Raises:
TypeError: If equation is not a string.
"""
super().__init__()
if not isinstance(equation, str):
raise TypeError(f"For einsum, 'equation' must be a str, but got {type(equation)}.")
self.equation = equation.replace(" ", "")
self.lhs, self.rhs, self.num_ops = _parse_equation(self.equation)
self.num_tensors = 2 * self.num_ops - 1
self.contract_dims = self._get_contract_dims()
self.ellipsis_idxes = _parse_ellipsis(self.lhs, self.rhs)
self.use_opt = use_opt
        # not initialized yet; filled in by _post_init on the first run
self.has_inited = False
self.trace = None
self.order_labels = None
self.diag_ops = None
self.sums_ops, self.perm_ops, self.step_ops = None, None, None
if not use_opt or self.num_ops < 2:
shapes = self._generate_a_random_shape()
self._post_init(shapes)
@staticmethod
def _count_labels(op_labels):
"""
Counts the occurrences of each unique label in the operation labels.
Args:
op_labels: list[str, ...]
Returns:
dict: A dictionary mapping each label to its count.
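        Example (illustrative): _count_labels(["ij", "jk", "ki"]) returns
        {'i': 2, 'j': 2, 'k': 2}, since every label appears in two operands.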
"""
letter_count = defaultdict(int)
for s in op_labels:
unique_letters = set(s)
for letter in unique_letters:
letter_count[letter] += 1
return dict(letter_count)
@staticmethod
def _bind_shape_with_label(in_shapes, op_labels, rt_list=True):
"""bind shape with label
Args:
in_shapes: tuple[tuple[int, ...], ...]
op_labels: list[str, ...]
rt_list: bool
        Return example (rt_list=False): [{'i': 2, 'j': 3}, {'j': 3, 'k': 4}, {'k': 4, 'i': 2}];
        with rt_list=True each entry is a list of (label, size) tuples instead.
"""
bound_shapes = []
for indices, shape in zip(op_labels, in_shapes):
if rt_list:
bound_shape = list(zip(indices, shape))
else:
bound_shape = dict(zip(indices, shape))
bound_shapes.append(bound_shape)
return bound_shapes
def _post_init(self, shapes):
"""
        Finish initialization. Called eagerly from __init__ with placeholder shapes when
        path optimization is not needed, otherwise lazily on the first forward run with
        the real input shapes.
        1. Derive the contraction path (via opt_einsum when use_opt is True)
        2. Apply label order optimization
        3. Build the calculation steps
"""
base_trace = parse_opt_trace(self.equation, shapes, self.use_opt)
op_labels, self.diag_ops, rm_diag_shapes = self._process_diagonal(shapes)
rm_diag_shapes = _replace_e1_shape(rm_diag_shapes)
tensor_infos = self._build_cacl_steps(rm_diag_shapes, op_labels, base_trace)
base_order = self._get_base_order()
order = LabelOrder(tensor_infos, base_trace, base_order)
self.order_labels, self.trace = order.get_order()
self.sums_ops, self.perm_ops, self.step_ops = self._build(rm_diag_shapes, op_labels, tensor_infos)
self.has_inited = True
def _get_base_order(self):
"""
Generates a base order string by appending characters from lhs to rhs,
excluding commas and duplicates.
Returns:
str: The base order string.
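        Example (illustrative): with lhs "ij,jk" and rhs "ik", the base order is "ikj".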
"""
res = self.rhs
for c in self.lhs:
if c != ',' and c not in res:
res += c
return res
def _process_diagonal(self, shapes):
"""
Processes the diagonal elements of the tensors specified by the operation labels.
Args:
shapes: The shapes of the tensors.
list[tuple[int, ...], ...]
Returns:
tuple: A tuple containing the new operation labels, diagonal operations, and new shapes.
"""
op_labels = self.lhs.split(",")
new_op_labels, new_shapes = [], []
diag_ops = []
for op, shape in zip(op_labels, shapes):
diag_pairs = []
while True:
diag_pair, op, shape = _remove_a_diagonal(op, shape)
if not diag_pair:
break
diag_pairs.append(tuple(diag_pair))
diag_ops.append(tuple(diag_pairs))
new_op_labels.append(op)
new_shapes.append(shape)
for _ in range(self.num_ops - 1):
diag_ops.append(None)
new_diag_ops = _flat_empty_struct(diag_ops)
return new_op_labels, new_diag_ops, new_shapes
def _generate_a_random_shape(self):
"""
        Generates placeholder shapes for the tensors based on the operation labels
        (every label is given an arbitrary size of 8).
Returns:
list of tuples: The generated shapes for the tensors.
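        Example (illustrative): for lhs "ij,jk" the result is [(8, 8), (8, 8)].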
"""
all_indices = set(self.lhs)
        # an arbitrary placeholder size for every label
dims = {idx: 8 for idx in all_indices}
op_labels = self.lhs.split(",")
input_shapes = []
for labels in op_labels:
shape = tuple(dims[label] for label in labels)
input_shapes.append(shape)
return input_shapes
def _get_contract_dims(self):
"""
Determines the dimensions to be contracted by comparing the sets of lhs and rhs.
Returns:
tuple: The dimensions to be contracted.
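        Example (illustrative): for "ij,jk->ik" the contracted labels are ('j',),
        the set difference of the lhs and rhs labels (the ',' separator is excluded).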
"""
set1 = set(self.lhs)
set2 = set(self.rhs + ",")
diff_set = set1 - set2
return tuple(diff_set)
def _build_cacl_steps(self, in_shapes, op_labels, base_trace):
"""
Builds the calculation steps for tensor contractions based on the input shapes and operation labels.
Args:
in_shapes (list of tuples): The shapes of the input tensors.
            op_labels: The list of labels; example: ["ijk", "zui", "zuj"]
base_trace: list of Int pair; like [(1, 0), (2, 3)]
Types:
in_shapes: tuple[tuple[int, ...], ...]
            op_labels: list[str, ...]
Returns:
tuple: A tuple of tuples, each containing input and calculation information for each step.
"""
label_counts = Einsum._count_labels(op_labels)
ops = Einsum._bind_shape_with_label(in_shapes, op_labels, rt_list=False)
cacl_info = [None] * self.num_tensors
input_info = []
for labels in op_labels:
input_info.append({"IN": labels, "FROM": C.T_INPUT})
for i, j in base_trace:
a_shape = ops[i]
b_shape = ops[j]
sum_labels = []
a_labels_to_sum, b_labels_to_sum = [], []
for d in self.contract_dims:
if d in a_shape and d in b_shape:
label_counts[d] -= 1
if label_counts[d] == 1:
sum_labels.append(d)
label_counts[d] = 0
elif label_counts[d] == 1:
if d in a_shape:
a_labels_to_sum.append(d)
label_counts[d] = 0
elif d in b_shape:
b_labels_to_sum.append(d)
label_counts[d] = 0
new_shape, a_info, b_info, out_info = sumproduct_pair_info(a_shape, b_shape, a_labels_to_sum,
b_labels_to_sum, sum_labels)
ops.append(new_shape)
input_info.append(out_info)
# dict of calculate info about: matmul or mul
cacl_info[i] = a_info
cacl_info[j] = b_info
cacl_info[-1] = out_cacl_info(ops[self.num_tensors - 1], self.rhs)
_update_weight(cacl_info)
res = tuple(zip(input_info, cacl_info))
return res
def _build(self, in_shapes, op_labels, ops):
"""
Builds the tensor operations and permutations for the given input shapes.
Args:
in_shapes (list of tuples): The shapes of the input tensors.
            op_labels (list): The list of labels; example: ["ijk", "zui", "zuj"].
ops (list): result of function _build_cacl_steps.
Types:
in_shapes: tuple[tuple[int, ...], ...]
op_labels: list[str, ...]
            ops: tuple[tuple[dict[str, str], ...], ...]
Returns:
tuple: A tuple containing the sum dimensions, permutations, and step operations.
"""
shape_infos = Einsum._bind_shape_with_label(in_shapes, op_labels, rt_list=True)
perm_ops = [None] * self.num_tensors
sums_ops = [None] * self.num_tensors
step_ops = []
for i, j in self.trace:
a_mul_sums, b_mul_sums = ops[i][0].get("SUMS", []), ops[j][0].get("SUMS", [])
a_info, b_info = ops[i][1], ops[j][1]
t_type = a_info["CACL"]
a_shape, a_sum_dims = _sum_dims_helper(shape_infos[i], a_info["SUMS"] + a_mul_sums)
b_shape, b_sum_dims = _sum_dims_helper(shape_infos[j], b_info["SUMS"] + b_mul_sums)
if t_type == C.T_MUL:
a_perm, b_perm, cacl_info, new_shape = rearrange_tensor_to_mul(self.order_labels, a_shape, b_shape)
else:
a_perm, b_perm, cacl_info, new_shape = rearrange_tensor_to_bmm(self.order_labels, a_shape,
a_info, b_shape, b_info)
shape_infos.append(new_shape)
perm_ops[i], perm_ops[j] = a_perm, b_perm
sums_ops[i], sums_ops[j] = a_sum_dims, b_sum_dims
step_ops.append((t_type, cacl_info))
# out
out_shape, out_sum_dims = _sum_dims_helper(shape_infos[self.num_tensors-1], self.contract_dims)
sums_ops[-1] = out_sum_dims
perm_ops[-1] = rearrange_tensor_to_out(out_shape, self.rhs)
sums_ops = _flat_empty_struct(sums_ops)
return sums_ops, tuple(perm_ops), tuple(step_ops)
def _reshape_ellipsis(self, operands):
"""
reshape the dims indicated by ellipses.
"""
if not self.ellipsis_idxes:
return operands, None
new_operands = mutable([])
elli_shapes = None
for i, op in enumerate(operands):
if self.ellipsis_idxes[i]:
new_shape, elli_shapes = _get_ellipsis_shape(op.shape,
self.ellipsis_idxes[i], None)
new_operands.append(op.reshape(new_shape))
else:
new_operands.append(op)
return new_operands, elli_shapes
def _reshape_ellipsis_out(self, out: Tensor, elli_shapes: tuple[int, ...]):
# note: elli_shapes may be tuple([])
if elli_shapes is not None and self.ellipsis_idxes[-1]:
new_shape, _ = _get_ellipsis_shape(out.shape, self.ellipsis_idxes[-1], elli_shapes)
return out.reshape(new_shape)
return out
def _apply_preprocess(self, t, i):
"""
Applies a series of preprocessing operations on the tensor `t` based on the operations
defined in `diag_ops`, `sums_ops`, and `perm_ops`.
Args:
- t (Tensor): The input tensor to be preprocessed.
- i (int): The index used to access the specific operations for this tensor.
Returns:
- Tensor: The preprocessed tensor.
"""
# diagonal
if self.diag_ops and self.diag_ops[i]:
for prev_dim, dim in self.diag_ops[i]:
t = t.diagonal(0, prev_dim, dim)
# sums
if self.sums_ops and self.sums_ops[i]:
t = mint.sum(t, dim=self.sums_ops[i], keepdim=False)
# permute
if self.perm_ops[i]:
t = mint.permute(t, self.perm_ops[i])
return t
def _check_inputargs(self, operands):
"""Check operands."""
if len(operands) != self.num_ops:
raise ValueError("The number of input tensors is inconsistent with the expression.")
for operand in operands:
if not isinstance(operand, Tensor):
raise TypeError(f"For einsum, members of 'operands' must be Tensor, but got {type(operand)}.")
def construct(self, *operands):
"""
Constructs the final output tensor by applying a series of operations defined in `trace` and `step_ops`.
Args:
- *operands: Variable number of input tensors.
Returns:
- Tensor. The final output tensor after applying all the operations.
"""
self._check_inputargs(operands)
operands, elli_shapes = self._reshape_ellipsis(operands)
if not self.has_inited:
shapes = [t.shape for t in operands]
self._post_init(shapes)
data = mutable(list(operands))
for k, tra in enumerate(self.trace):
i, j = tra
t_type, bmm_info = self.step_ops[k]
# Apply preprocessing to the selected tensors
ta = self._apply_preprocess(data[i], i)
tb = self._apply_preprocess(data[j], j)
# Perform the specified operation (mul or matmul)
if t_type == C.T_MUL:
ta = _cacl_mul_reshape(ta, bmm_info[0])
tb = _cacl_mul_reshape(tb, bmm_info[1])
t_out = ta * tb
else:
mm_class = P.BatchMatMul if bmm_info[0] else P.MatMul
matmul = _get_cache_prim(mm_class)(transpose_a=bmm_info[1], transpose_b=bmm_info[2])
ta, tb, out_shape = _cacl_matmul_reshape(ta, tb, bmm_info)
t_out = matmul(ta, tb)
if out_shape:
t_out = t_out.reshape(out_shape)
# append new tensor
data.append(t_out)
# Apply final preprocessing to the last tensor
n = self.num_tensors - 1
out_tensor = self._apply_preprocess(data[n], n)
out_tensor = self._reshape_ellipsis_out(out_tensor, elli_shapes)
return out_tensor