mindspore.mint.baddbmm

查看源文件
mindspore.mint.baddbmm(input, batch1, batch2, *, beta=1, alpha=1)[源代码]

batch1batch2 中的矩阵相乘,并与 input 相加。

说明

  • batch1batch2 必须是三维的tensor,且包含相同数量的矩阵。

  • 如果 batch1 是大小为 \((C, W, T)\) 的tensor, batch2 是大小为 \((C, T, H)\) 的tensor, 则 input 必须能够与大小为 \((C, W, H)\) 的tensor进行广播,且输出将是大小为 \((C, W, H)\) 的tensor。

  • beta 为0,那么 input 将会被忽略。

  • 当输入的类型不是 FloatTensor 时,参数 betaalpha 必须是整数。

\[\text{out}_{i} = \beta \text{input}_{i} + \alpha (\text{batch1}_{i} \mathbin{@} \text{batch2}_{i})\]
参数:
  • input (Tensor) - 输入tensor。

  • batch1 (Tensor) - 第一个batch矩阵。

  • batch2 (Tensor) - 第二个batch矩阵。

关键字参数:
  • beta (Union[float, int], 可选) - input 的尺度因子。默认 1

  • alpha (Union[float, int],可选) - ( batch1 @ batch2 )的尺度因子,默认 1

返回:

Tensor

支持平台:

Ascend

样例:

>>> import mindspore
>>> input = mindspore.mint.ones([3, 3])
>>> batch1 = mindspore.mint.arange(24.0).reshape((2, 3, 4))
>>> batch2 = mindspore.mint.arange(24.0).reshape((2, 4, 3))
>>> mindspore.mint.baddbmm(input, batch1, batch2)
Tensor(shape=[2, 3, 3], dtype=Float32, value=
[[[ 4.30000000e+01,  4.90000000e+01,  5.50000000e+01],
  [ 1.15000000e+02,  1.37000000e+02,  1.59000000e+02],
  [ 1.87000000e+02,  2.25000000e+02,  2.63000000e+02]],
 [[ 9.07000000e+02,  9.61000000e+02,  1.01500000e+03],
  [ 1.17100000e+03,  1.24100000e+03,  1.31100000e+03],
  [ 1.43500000e+03,  1.52100000e+03,  1.60700000e+03]]])