mindspore.ops.split

mindspore.ops.split(tensor, split_size_or_sections, axis=0)

Split the tensor into chunks along the given axis.

Parameters
  • tensor (Tensor) – The input tensor.

  • split_size_or_sections (Union[int, tuple(int), list(int)]) – If an int, the size of each chunk; if a tuple or list, the size of each individual chunk.

  • axis (int, optional) – The axis along which to split. Default: 0.

Note

  • If split_size_or_sections is an int, the input tensor is divided along axis into chunks of size split_size_or_sections. If tensor.shape[axis] is not divisible by split_size_or_sections, the last chunk is smaller and holds the remainder.

  • If split_size_or_sections is a tuple or list, tensor is split along axis into len(split_size_or_sections) chunks whose sizes are given by the elements of split_size_or_sections. These sizes must sum to tensor.shape[axis].
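The two rules above can be condensed into a small pure-Python sketch that only computes chunk sizes, without creating any tensors. `chunk_sizes` is a hypothetical helper, not part of MindSpore, and its two `ValueError` checks are assumptions about which inputs are invalid rather than a description of MindSpore's exact error handling.

```python
def chunk_sizes(dim_size, split_size_or_sections):
    """Hypothetical helper: sizes of the chunks produced along an axis of
    length dim_size, following the two rules in the Note."""
    if isinstance(split_size_or_sections, int):
        size = split_size_or_sections
        # Assumption: a non-positive chunk size is treated as invalid.
        if size <= 0:
            raise ValueError("split size must be a positive int")
        # As many full chunks of `size` as fit, then one smaller chunk
        # for the remainder (only if there is one).
        full, remainder = divmod(dim_size, size)
        return [size] * full + ([remainder] if remainder else [])
    # Tuple/list form: one chunk per entry, with exactly the requested sizes.
    sizes = list(split_size_or_sections)
    # Assumption: sizes that do not cover the axis exactly are rejected.
    if sum(sizes) != dim_size:
        raise ValueError("sections must sum to the axis length")
    return sizes


print(chunk_sizes(10, 3))          # [3, 3, 3, 1]
print(chunk_sizes(10, [3, 3, 4]))  # [3, 3, 4]
```

For an axis of length 10, these are the same chunk shapes ([3], [3], [3], [1] and [3], [3], [4]) that appear in the examples below.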

Returns

tuple(Tensor), the chunks in order along axis.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> # case1: `split_size_or_sections` is an int type
>>> input_x = mindspore.ops.arange(10).astype("float32")
>>> output = mindspore.ops.split(tensor=input_x, split_size_or_sections=3)
>>> print(output)
(Tensor(shape=[3], dtype=Float32, value=[0.00000000e+00, 1.00000000e+00, 2.00000000e+00]),
 Tensor(shape=[3], dtype=Float32, value=[3.00000000e+00, 4.00000000e+00, 5.00000000e+00]),
 Tensor(shape=[3], dtype=Float32, value=[6.00000000e+00, 7.00000000e+00, 8.00000000e+00]),
 Tensor(shape=[1], dtype=Float32, value=[9.00000000e+00]))
>>> # case2: `split_size_or_sections` is a list type
>>> output = mindspore.ops.split(tensor=input_x, split_size_or_sections=[3, 3, 4])
>>> print(output)
(Tensor(shape=[3], dtype=Float32, value=[0.00000000e+00, 1.00000000e+00, 2.00000000e+00]),
 Tensor(shape=[3], dtype=Float32, value=[3.00000000e+00, 4.00000000e+00, 5.00000000e+00]),
 Tensor(shape=[4], dtype=Float32, value=[6.00000000e+00, 7.00000000e+00, 8.00000000e+00, 9.00000000e+00]))
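The examples above only split a 1-D tensor along axis 0. As a rough reference for other axes, the sketch below models the tuple/list form on a 2-D array with NumPy, assuming NumPy is available. `split_by_sizes` is a hypothetical helper, not a MindSpore API, and its sum check is an assumption about input validation. It splits a 3x4 array along axis=1 into chunks of width 1 and 3, then concatenates them to show that the chunks tile the input in order.

```python
import numpy as np


def split_by_sizes(x, sizes, axis=0):
    """Hypothetical NumPy model of the tuple/list form of split()."""
    # Assumption: the sizes must cover the chosen axis exactly.
    if sum(sizes) != x.shape[axis]:
        raise ValueError("sizes must sum to x.shape[axis]")
    # np.split takes cut points (running offsets), not chunk sizes, so
    # convert with cumsum and drop the last offset (it equals x.shape[axis]).
    cut_points = np.cumsum(sizes)[:-1]
    return tuple(np.split(x, cut_points, axis=axis))


x = np.arange(12, dtype=np.float32).reshape(3, 4)
parts = split_by_sizes(x, [1, 3], axis=1)
print([p.shape for p in parts])                          # [(3, 1), (3, 3)]
print(np.array_equal(np.concatenate(parts, axis=1), x))  # True
```

np.split is used here because it expects cut points, hence the cumsum conversion; mindspore.ops.split itself takes the chunk sizes directly.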