mindspore.ops.BatchMatMul

查看源文件
class mindspore.ops.BatchMatMul(transpose_a=False, transpose_b=False)[源代码]

基于batch维度的两个Tensor的矩阵乘法。

\[\text{output}[..., :, :] = \text{matrix}(x[..., :, :]) * \text{matrix}(y[..., :, :])\]

两个输入Tensor必须具有相同的秩,并且秩必须不小于 2

参数:
  • transpose_a (bool) - 如果为 True ,则在乘法之前转置 x 的最后两个维度。默认值: False

  • transpose_b (bool) - 如果为 True ,则在乘法之前转置 y 的最后两个维度。默认值: False

输入:
  • x (Tensor) - 输入相乘的第一个Tensor。其shape为 \((*B, N, C)\) ,其中 \(*B\) 表示批处理大小,可以是多维度, \(N\)\(C\) 是最后两个维度的大小。如果 transpose_a 为True,则其shape必须为 \((*B, C, N)\)

  • y (Tensor) - 输入相乘的第二个Tensor。Tensor的shape为 \((*B, C, M)\) 。如果 transpose_b 为True,则其shape必须为 \((*B, M, C)\)

输出:

Tensor,输出Tensor的shape为 \((*B, N, M)\)

异常:
  • TypeError - transpose_atranspose_b 不是bool。

  • ValueError - x 的shape长度不等于 y 的shape长度或输入的shape长度小于2。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> x = Tensor(np.ones(shape=[2, 4, 1, 3]), mindspore.float32)
>>> y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
>>> batmatmul = ops.BatchMatMul()
>>> output = batmatmul(x, y)
>>> print(output.shape)
(2, 4, 1, 4)
>>> x = Tensor(np.ones(shape=[2, 4, 3, 1]), mindspore.float32)
>>> y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
>>> batmatmul = ops.BatchMatMul(transpose_a=True)
>>> output = batmatmul(x, y)
>>> print(output.shape)
(2, 4, 1, 4)