mindspore.nn.Tril

class mindspore.nn.Tril

Returns a tensor with elements above the kth diagonal zeroed.

The lower triangular part of the matrix is defined as the elements on and below the diagonal.

The parameter k selects the diagonal that bounds the retained region. If k = 0, all elements on and below the main diagonal are retained. A positive k additionally retains that many diagonals above the main diagonal, and a negative k excludes that many diagonals starting from the main diagonal and moving down.
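The effect of k can be stated as an index rule: the element at row i, column j is kept when j - i <= k and zeroed otherwise. The following is a minimal sketch of that rule using NumPy rather than MindSpore; the helper name tril_mask is hypothetical, and the comparison with np.tril only illustrates the convention, not how nn.Tril is implemented.

```python
import numpy as np

def tril_mask(rows, cols, k=0):
    """Hypothetical helper: boolean mask that is True where column - row <= k."""
    i = np.arange(rows)[:, None]   # row indices as a column vector
    j = np.arange(cols)[None, :]   # column indices as a row vector
    return (j - i) <= k

x = np.arange(1, 17).reshape(4, 4)
for k in (-1, 0, 1, 2):
    # Keep elements where the mask is True, zero the rest.
    manual = np.where(tril_mask(4, 4, k), x, 0)
    # NumPy's own lower-triangle routine follows the same rule.
    assert np.array_equal(manual, np.tril(x, k))
```

With k = -1 the main diagonal itself fails the test j - i <= k, which is why it is zeroed along with everything above it.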

Inputs:
  • x (Tensor) - The input tensor. The data type is Number. Its shape is \((N,*)\), where \(*\) means any number of additional dimensions.

  • k (int) - The index of the diagonal. Default: 0.

Outputs:

Tensor, with the same shape and data type as the input x.
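As a rough illustration of the shape and type guarantee, the sketch below applies the analogous NumPy routine np.tril to a non-square float32 matrix. It is a NumPy analogue under the assumption that both follow the same convention, not a run of nn.Tril itself.

```python
import numpy as np

# A 3x4 float32 matrix: the input does not have to be square.
x = np.arange(12, dtype=np.float32).reshape(3, 4)

# Zero everything above the first diagonal over the main one (k = 1).
y = np.tril(x, k=1)

# Shape and dtype are carried over unchanged; only values are zeroed.
assert y.shape == x.shape
assert y.dtype == x.dtype
print(y)
```

Only the values above the chosen diagonal change; nothing is cropped or cast.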

Raises
Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> from mindspore import Tensor, nn
>>> x = Tensor(np.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [10, 11, 12, 13],
...                      [14, 15, 16, 17]]))
>>> tril = nn.Tril()
>>> result = tril(x)
>>> print(result)
[[ 1  0  0  0]
 [ 5  6  0  0]
 [10 11 12  0]
 [14 15 16 17]]
>>> x = Tensor(np.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [10, 11, 12, 13],
...                      [14, 15, 16, 17]]))
>>> tril = nn.Tril()
>>> result = tril(x, 1)
>>> print(result)
[[ 1  2  0  0]
 [ 5  6  7  0]
 [10 11 12 13]
 [14 15 16 17]]
>>> x = Tensor(np.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [10, 11, 12, 13],
...                      [14, 15, 16, 17]]))
>>> tril = nn.Tril()
>>> result = tril(x, 2)
>>> print(result)
[[ 1  2  3  0]
 [ 5  6  7  8]
 [10 11 12 13]
 [14 15 16 17]]
>>> x = Tensor(np.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [10, 11, 12, 13],
...                      [14, 15, 16, 17]]))
>>> tril = nn.Tril()
>>> result = tril(x, -1)
>>> print(result)
[[ 0  0  0  0]
 [ 5  0  0  0]
 [10 11  0  0]
 [14 15 16  0]]