mindspore.ops.TriuIndices

查看源文件
class mindspore.ops.TriuIndices(row, col, offset=0, dtype=mstype.int32)[源代码]

计算 row * col 行列矩阵的上三角元素的索引,并将它们作为一个 2xN 的Tensor返回。

警告

这是一个实验性API,后续可能修改或删除。

更多参考详见 mindspore.ops.triu_indices()

参数:
  • row (int) - 2-D 矩阵的行数。

  • col (int) - 2-D 矩阵的列数。

  • offset (int,可选) - 对角线偏移量。默认值: 0

  • dtype (mindspore.dtype,可选) - 指定输出Tensor数据类型,支持的数据类型为 mstype.int32mstype.int64 ,默认值: mstype.int32

输出:
  • y (Tensor) - 矩阵的下三角形部分的索引。数据类型由 dtype 指定,shape为 \((2, tril\_size)\) ,其中,\(tril\_size\) 为上三角矩阵的元素总数。

支持平台:

Ascend GPU CPU

样例:

>>> from mindspore import ops
>>> from mindspore import dtype as mstype
>>> net = ops.TriuIndices(5, 4, 2, mstype.int64)
>>> output = net()
>>> print(output)
[[0 0 1]
 [2 3 3]]
>>> print(output.dtype)
Int64