mindspore.ops.Split

查看源文件
class mindspore.ops.Split(axis=0, output_num=1)[源代码]

根据指定的轴和分割数量对输入Tensor进行分割。

更多参考详见 mindspore.ops.split()

参数:
  • axis (int) - 指定分割轴。默认值: 0

  • output_num (int) - 指定分割数量。其值为正整数。默认值: 1

输入:
  • input_x (Tensor) - Tensor的shape为 \((x_0, x_1, ..., x_{R-1})\) ,其中R >= 1。

输出:

tuple[Tensor],每个输出Tensor的shape相同,为 \((x_0, x_1, ..., x_{axis}/{output\_num}, ..., x_{R-1})\) 。数据类型与 input_x 的相同。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> split = ops.Split(1, 2)
>>> x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]), mindspore.int32)
>>> print(x)
[[1 1 1 1]
 [2 2 2 2]]
>>> output = split(x)
>>> print(output)
(Tensor(shape=[2, 2], dtype=Int32, value=
[[1, 1],
 [2, 2]]), Tensor(shape=[2, 2], dtype=Int32, value=
[[1, 1],
 [2, 2]]))
>>> split = ops.Split(1, 4)
>>> output = split(x)
>>> print(output)
(Tensor(shape=[2, 1], dtype=Int32, value=
[[1],
 [2]]), Tensor(shape=[2, 1], dtype=Int32, value=
[[1],
 [2]]), Tensor(shape=[2, 1], dtype=Int32, value=
[[1],
 [2]]), Tensor(shape=[2, 1], dtype=Int32, value=
[[1],
 [2]]))