mindspore.nn.Unflatten

查看源文件
class mindspore.nn.Unflatten(axis, unflattened_size)[源代码]

根据 axisunflattened_size 折叠指定维度为给定形状。

参数:
  • axis (int) - 指定输入Tensor被折叠维度。

  • unflattened_size (Union(tuple[int], list[int])) - 指定维度维度折叠后的新shape,可以为tuple[int]或者list[int]。 unflattened_size 中各元素的乘积必须等于input_shape[axis]。

输入:
  • input (Tensor) - 进行折叠操作的Tensor。

输出:

折叠操作后的Tensor。

异常:
  • TypeError - axis 不是int。

  • TypeError - unflattened_size 既不是tuple[int]也不是list[int]。

  • TypeError - unflattened_size 中各元素的乘积不等于input_shape[axis]。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> from mindspore import Tensor, nn
>>> import numpy as np
>>> input = Tensor(np.arange(0, 100).reshape(2, 10, 5), mindspore.float32)
>>> net = nn.Unflatten(1, (2, 5))
>>> output = net(input)
>>> print(f"before unflatten the input shape is {input.shape}")
before unflatten the input shape is  (2, 10, 5)
>>> print(f"after unflatten the output shape is {output.shape}")
after unflatten the output shape is (2, 2, 5, 5)