mindspore.ops.batch_dot

mindspore.ops.batch_dot(x1, x2, axes=None)

ops.batch_dot is deprecated as of version 2.8.0 and will be removed in a future version.

Computes the batched dot product between corresponding samples of two tensors that share a leading batch dimension.

Note

The first dimension of x1 and x2 is the batch size. The data type must be float32, and the rank of each input must be greater than or equal to 2.

\[\text{output} = x1[\text{batch}, :] \cdot x2[\text{batch}, :]\]
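To make the formula concrete, here is a minimal pure-Python sketch of the most common case: x1 of shape (batch, m, k), x2 of shape (batch, k, n), contracted over the last axis of x1 and the middle axis of x2 (axes=(-1, -2), as in case 1 of the examples). The helper batch_matmul_3d is hypothetical and only illustrates the per-sample dot product; it is not how MindSpore implements the operator.

```python
# Hypothetical helper (not part of MindSpore): per-sample dot product for
# x1 of shape (batch, m, k) and x2 of shape (batch, k, n), i.e. axes=(-1, -2).
def batch_matmul_3d(x1, x2):
    """Return out[b][i][j] = sum over t of x1[b][i][t] * x2[b][t][j]."""
    if len(x1) != len(x2):
        raise ValueError("x1 and x2 must have the same batch size")
    out = []
    for a, b in zip(x1, x2):  # one (m, k) matrix and one (k, n) matrix per sample
        k = len(b)
        if any(len(row) != k for row in a):
            raise ValueError("contracted dimensions must match")
        n = len(b[0])
        out.append([
            [sum(row[t] * b[t][j] for t in range(k)) for j in range(n)]
            for row in a
        ])
    return out


# Same inputs as case 1: all-ones tensors of shape (2, 2, 3) and (2, 3, 2).
x1 = [[[1.0] * 3 for _ in range(2)] for _ in range(2)]
x2 = [[[1.0] * 2 for _ in range(3)] for _ in range(2)]
print(batch_matmul_3d(x1, x2))  # [[[3.0, 3.0], [3.0, 3.0]], [[3.0, 3.0], [3.0, 3.0]]]
```

Each output entry is 3.0 because it sums three products of ones along the contracted axis of size 3, which matches the printed result of case 1.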
Parameters
  • x1 (Tensor) – The first input tensor.

  • x2 (Tensor) – The second input tensor.

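  • axes (Union[int, tuple(int), list(int)]) – The axes to contract. A tuple or list of length 2 gives the axis of x1 and the axis of x2, respectively; a single int is used as the axis for both inputs. Axes index the full shape, including the batch dimension, and negative values count from the end. Default: None, which contracts the last axis of x1 with the second-to-last axis of x2.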
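The rules for axes can be sketched as a small normalization step that turns None, an int, or a pair into two non-negative axis indices. The function resolve_batch_dot_axes is hypothetical. Two behaviors are assumptions of this sketch rather than documented facts: it refuses to contract the batch axis 0, and it does not model axes=None when x2 has rank 2 (for that rank the "second-to-last axis" would be the batch axis).

```python
# Hypothetical helper (not part of MindSpore): normalize `axes` into a pair
# (axis_x1, axis_x2) of non-negative indices into the full input shapes.
def resolve_batch_dot_axes(x1_rank, x2_rank, axes=None):
    if x1_rank < 2 or x2_rank < 2:
        raise ValueError("both inputs must have rank >= 2")
    if axes is None:
        # Default shown in case 2: last axis of x1, second-to-last axis of x2.
        # Assumption: only modeled for x2 of rank >= 3.
        if x2_rank < 3:
            raise ValueError("this sketch does not model axes=None for rank-2 x2")
        return x1_rank - 1, x2_rank - 2
    if isinstance(axes, int):
        pair = (axes, axes)  # a single int applies to both inputs (case 3)
    elif isinstance(axes, (tuple, list)) and len(axes) == 2:
        pair = tuple(axes)
    else:
        raise TypeError("axes must be None, an int, or a tuple/list of two ints")
    resolved = []
    for axis, rank in zip(pair, (x1_rank, x2_rank)):
        if axis < 0:
            axis += rank  # negative axes count from the end (case 1)
        # Assumption: the batch axis 0 cannot be contracted.
        if not 1 <= axis < rank:
            raise ValueError(f"axis {axis} is out of range for a rank-{rank} input")
        resolved.append(axis)
    return tuple(resolved)


print(resolve_batch_dot_axes(3, 3, (-1, -2)))  # (2, 1)
print(resolve_batch_dot_axes(4, 4))            # (3, 2)
print(resolve_batch_dot_axes(3, 4, 2))         # (2, 2)
```

The three calls correspond to the axes used in cases 1, 2, and 3 of the examples.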

Returns

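Tensor, the batched dot product of x1 and x2. Its shape is the batch size followed by the remaining dimensions of x1 and then the remaining dimensions of x2, with the contracted axes removed.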
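The output shape rule can be checked against the examples with a short sketch. The helper batch_dot_output_shape is hypothetical and takes already-resolved, non-negative, non-batch axes. It reproduces the shapes printed in the examples, but it does not attempt to match the operator when neither input has any dimension left after contraction, since no example covers that case.

```python
# Hypothetical helper (not part of MindSpore): output shape of a batched dot
# product, given non-negative, non-batch contraction axes for x1 and x2.
def batch_dot_output_shape(x1_shape, x2_shape, axis_x1, axis_x2):
    if x1_shape[0] != x2_shape[0]:
        raise ValueError("batch sizes differ")
    if x1_shape[axis_x1] != x2_shape[axis_x2]:
        raise ValueError("contracted dimensions differ")
    # Keep the batch dimension, then x1's remaining axes, then x2's remaining axes.
    rest_x1 = [d for i, d in enumerate(x1_shape) if i not in (0, axis_x1)]
    rest_x2 = [d for i, d in enumerate(x2_shape) if i not in (0, axis_x2)]
    return (x1_shape[0], *rest_x1, *rest_x2)


print(batch_dot_output_shape((6, 2, 3, 4), (6, 5, 4, 8), 3, 2))  # (6, 2, 3, 5, 8)
print(batch_dot_output_shape((2, 2, 4), (2, 5, 4, 5), 2, 2))     # (2, 2, 5, 5)
print(batch_dot_output_shape((2, 2), (2, 3, 2), 1, 2))           # (2, 3)
```

The three calls reproduce the shapes printed for case 2, case 3, and the second part of case 1.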

Supported Platforms:

Deprecated

Examples

>>> import mindspore
>>> # case 1: axes is a tuple (axis of `x1`, axis of `x2`)
>>> x1 = mindspore.ops.ones([2, 2, 3])
>>> x2 = mindspore.ops.ones([2, 3, 2])
>>> axes = (-1, -2)
>>> output = mindspore.ops.batch_dot(x1, x2, axes)
>>> print(output)
[[[3. 3.]
  [3. 3.]]
 [[3. 3.]
  [3. 3.]]]
>>> print(output.shape)
(2, 2, 2)
>>> x1 = mindspore.ops.ones([2, 2], mindspore.float32)
>>> x2 = mindspore.ops.ones([2, 3, 2], mindspore.float32)
>>> axes = (1, 2)
>>> output = mindspore.ops.batch_dot(x1, x2, axes)
>>> print(output)
[[2. 2. 2.]
 [2. 2. 2.]]
>>> print(output.shape)
(2, 3)
>>>
>>> # case 2: axes is None
>>> x1 = mindspore.ops.ones([6, 2, 3, 4], mindspore.float32)
>>> x2 = mindspore.ops.ones([6, 5, 4, 8], mindspore.float32)
>>> output = mindspore.ops.batch_dot(x1, x2)
>>> print(output.shape)
(6, 2, 3, 5, 8)
>>>
>>> # case 3: axes is an int, used as the axis for both inputs
>>> x1 = mindspore.ops.ones([2, 2, 4])
>>> x2 = mindspore.ops.ones([2, 5, 4, 5])
>>> output = mindspore.ops.batch_dot(x1, x2, 2)
>>> print(output.shape)
(2, 2, 5, 5)