mindspore.ops.FillDiagonal

class mindspore.ops.FillDiagonal(fill_value, wrap=False)

Fills the main diagonal of a Tensor in-place with a specified value and returns the result. The input must have at least 2 dimensions. When it has more than 2 dimensions, all of its dimensions must have the same length.
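In the non-wrapping case, the main diagonal is the set of elements whose indices are equal along every axis, x[i, i, ..., i]. For a 2-D matrix, i runs up to the smaller of the two dimensions. The NumPy sketch below illustrates that reading. It assumes FillDiagonal follows the same semantics as numpy.fill_diagonal and that the input already satisfies the shape rules listed under Raises. The helper name fill_main_diagonal is hypothetical and not part of MindSpore. It works on a copy and returns it, so the caller's array is left unchanged.

```python
import numpy as np


def fill_main_diagonal(x: np.ndarray, fill_value: float) -> np.ndarray:
    """Hypothetical helper: return a copy of x with x[i, i, ..., i] set to fill_value.

    Assumes x is already valid (ndim >= 2, and all dims equal when ndim > 2).
    """
    y = x.copy()
    # Without wrapping, the diagonal ends at the shortest dimension.
    idx = np.arange(min(y.shape))
    # Use the same index array on every axis: y[idx, idx, ..., idx].
    y[(idx,) * y.ndim] = fill_value
    return y


cube = np.zeros((3, 3, 3), dtype=np.float32)
out = fill_main_diagonal(cube, 9.9)
print(out[2, 2, 2], out[0, 1, 2])  # a diagonal element, then an off-diagonal one
```

For a 3-D input only the three elements out[0, 0, 0], out[1, 1, 1] and out[2, 2, 2] change; every other element keeps its original value.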

Warning

This is an experimental API that is subject to change or deletion.

Parameters
  • fill_value (float) – The value to fill the diagonal of input_x.

  • wrap (bool, optional) – Controls whether the diagonal continues onto the remaining rows of a tall matrix (a matrix with more rows than columns). The examples below show how it works on a tall matrix when wrap is set to True. Default: False.
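For 2-D input, the wrap behaviour is easiest to see in the row-major (flattened) layout: consecutive diagonal elements are cols + 1 positions apart, and wrap decides whether the striding stops after the top cols x cols block or continues to the end of the array. The sketch below is modelled on numpy.fill_diagonal. The helper name fill_diagonal_2d is hypothetical, and the assumption that FillDiagonal uses the same rule rests on the wrap=True example in the Examples section, which this sketch reproduces.

```python
import numpy as np


def fill_diagonal_2d(x: np.ndarray, fill_value: float, wrap: bool = False) -> np.ndarray:
    """Hypothetical 2-D sketch of the wrap rule, modelled on numpy.fill_diagonal."""
    if x.ndim != 2:
        raise ValueError("this sketch only handles 2-D input")
    y = x.copy()
    cols = y.shape[1]
    # In row-major order, one row down and one column right is cols + 1 elements away.
    step = cols + 1
    # Without wrap, stop after the top cols x cols block (for a wide matrix this bound
    # lies past the end of the array, so the slice simply stops at the end).
    end = None if wrap else cols * cols
    y.flat[:end:step] = fill_value
    return y


x = np.repeat(np.arange(7), 3).reshape(7, 3)  # rows [0 0 0], [1 1 1], ..., [6 6 6]
print(fill_diagonal_2d(x, 9, wrap=False))
print(fill_diagonal_2d(x, 9, wrap=True))
```

With wrap=True on this 7 x 3 matrix, the filled positions are flat indices 0, 4, 8, 12, 16 and 20, so row 3 is skipped and the diagonal restarts at row 4, column 0.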

Inputs:
  • input_x (Tensor) - The input tensor, of shape \((x_1, x_2, ..., x_R)\). Its data type must be float32, int32 or int64.

Outputs:
  • y (Tensor) - A Tensor with the same shape and data type as input_x.

Raises
  • TypeError – If the data type of input_x is not float32, int32 or int64.

  • ValueError – If input_x has fewer than 2 dimensions.

  • ValueError – If input_x has more than 2 dimensions and its dimensions are not all of equal length.
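The Raises list amounts to a dtype check and two shape checks. The sketch below restates them as a NumPy pre-check, which can be handy for validating data before it is wrapped in a Tensor. check_fill_diagonal_input is a hypothetical helper, not a MindSpore API, and the order in which it reports errors is an assumption.

```python
import numpy as np

# Data types listed in the Inputs section.
_ALLOWED_DTYPES = (np.dtype(np.float32), np.dtype(np.int32), np.dtype(np.int64))


def check_fill_diagonal_input(x: np.ndarray) -> None:
    """Hypothetical pre-check mirroring the documented TypeError/ValueError rules."""
    if x.dtype not in _ALLOWED_DTYPES:
        raise TypeError(f"unsupported dtype {x.dtype}; expected float32, int32 or int64")
    if x.ndim < 2:
        raise ValueError(f"input must have at least 2 dimensions, got {x.ndim}")
    if x.ndim > 2 and len(set(x.shape)) != 1:
        raise ValueError(f"all dimensions must be equal when ndim > 2, got shape {x.shape}")


check_fill_diagonal_input(np.zeros((6, 3), dtype=np.int32))  # tall 2-D input: accepted
try:
    check_fill_diagonal_input(np.zeros((2, 3, 3), dtype=np.float32))
except ValueError as err:
    print("rejected:", err)
```

A non-square 2-D input passes, because the equal-length rule only applies above 2 dimensions.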

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32))
>>> fill_value = 9.9
>>> fill_diagonal = ops.FillDiagonal(fill_value)
>>> y = fill_diagonal(x)
>>> print(y)
[[9.9 2.  3. ]
 [4.  9.9 6. ]
 [7.  8.  9.9]]
>>> x = Tensor(np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5]]).astype(np.int32))
>>> fill_value = 9.0
>>> fill_diagonal = ops.FillDiagonal(fill_value)
>>> y = fill_diagonal(x)
>>> print(y)
[[9 0 0]
 [1 9 1]
 [2 2 9]
 [3 3 3]
 [4 4 4]
 [5 5 5]]
>>> x = Tensor(np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3],
...                      [4, 4, 4], [5, 5, 5], [6, 6, 6]]).astype(np.int64))
>>> fill_value = 9.0
>>> wrap = True
>>> fill_diagonal = ops.FillDiagonal(fill_value, wrap)
>>> y = fill_diagonal(x)
>>> print(y)
[[9 0 0]
 [1 9 1]
 [2 2 9]
 [3 3 3]
 [9 4 4]
 [5 9 5]
 [6 6 9]]