mindspore.nn.MatrixDiagPart

class mindspore.nn.MatrixDiagPart

Returns the batched diagonal part of a batched tensor, that is, the main diagonal of each matrix formed by the last two dimensions.

Assume x has \(k\) dimensions \([I, J, K, \ldots, M, N]\). The output is then a tensor of rank \(k-1\) with dimensions \([I, J, K, \ldots, \min(M, N)]\), where \(output[i, j, k, \ldots, n] = x[i, j, k, \ldots, n, n]\) for \(0 \le n < \min(M, N)\).
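The element formula can be checked against a small NumPy reference. This is a sketch for illustration only: `matrix_diag_part_reference` is a hypothetical helper, not part of MindSpore, and NumPy's `np.diagonal` (which removes the two given axes and appends the diagonal as the last axis) is used only as a cross-check.

```python
import numpy as np


def matrix_diag_part_reference(x: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy reference: output[..., n] = x[..., n, n]."""
    if x.ndim < 2:
        raise ValueError("x must have rank >= 2")
    m, n = x.shape[-2:]
    d = min(m, n)
    # Drop the last two axes and append one axis of length min(M, N).
    out = np.empty(x.shape[:-2] + (d,), dtype=x.dtype)
    for i in range(d):
        # Copy the i-th diagonal element of every matrix in the batch at once.
        out[..., i] = x[..., i, i]
    return out


# A batch of two non-square 3x4 matrices holding 0..23.
x = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)
out = matrix_diag_part_reference(x)
print(out)  # values [[0, 5, 10], [12, 17, 22]]
# np.diagonal over the last two axes gives the same result.
print(np.array_equal(out, np.diagonal(x, axis1=-2, axis2=-1)))  # True
```

Only the first \(\min(M, N)\) diagonal entries exist, so for the 3x4 matrices above the output has length 3 along its last axis.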

Inputs:
  • x (Tensor) - The batched tensor, with rank at least 2. Its data type must be one of float32, float16, int32, int8, and uint8.

Outputs:

Tensor with the same data type as the input x. Its shape is x.shape[:-2] + [min(x.shape[-2:])].
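The shape rule is plain tuple arithmetic. `output_shape` below is a hypothetical helper used only to illustrate the formula; it does not call MindSpore.

```python
def output_shape(shape):
    """Hypothetical helper: x.shape[:-2] + [min(x.shape[-2:])] for a shape tuple."""
    shape = tuple(shape)
    if len(shape) < 2:
        raise ValueError("input shape must have at least 2 dimensions")
    # Keep the batch dimensions, then add one axis for the diagonal of each matrix.
    return shape[:-2] + (min(shape[-2:]),)


print(output_shape((3, 2, 2)))     # (3, 2), matching the Examples section
print(output_shape((4, 5, 3, 7)))  # (4, 5, 3)
print(output_shape((6, 2)))        # (2,): a single matrix yields a vector
```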

Raises:

TypeError – If the dtype of x is not one of float32, float16, int32, int8, or uint8.
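Because only five data types are accepted, data prepared with NumPy may need an explicit cast first: NumPy defaults to float64 for floating-point arrays, and `np.arange` typically yields int64 on 64-bit Linux and macOS. The sketch below mirrors the documented TypeError with a hypothetical `check_dtype` helper on NumPy arrays; it is an illustration of the constraint, not MindSpore's actual validation code.

```python
import numpy as np

# Data types the documentation lists as supported (hypothetical constant).
SUPPORTED_DTYPES = {np.dtype(name) for name in ("float32", "float16", "int32", "int8", "uint8")}


def check_dtype(x: np.ndarray) -> None:
    """Hypothetical check mirroring the documented TypeError."""
    if x.dtype not in SUPPORTED_DTYPES:
        expected = ", ".join(sorted(str(d) for d in SUPPORTED_DTYPES))
        raise TypeError(f"unsupported dtype {x.dtype}; expected one of {expected}")


x64 = np.zeros((2, 3, 3))  # NumPy's default floating-point dtype is float64
try:
    check_dtype(x64)
except TypeError as err:
    print("rejected:", err)

x32 = x64.astype(np.float32)  # cast to a supported type before use
check_dtype(x32)  # passes without raising
```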

Supported Platforms:

Ascend

Examples:

>>> import mindspore
>>> from mindspore import Tensor, nn
>>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
>>> matrix_diag_part = nn.MatrixDiagPart()
>>> output = matrix_diag_part(x)
>>> print(output)
[[-1.  1.]
 [-1.  1.]
 [-1.  1.]]