mindspore.nn.ChannelShuffle

查看源文件
class mindspore.nn.ChannelShuffle(groups)[源代码]

将shape为 \((*, C, H, W)\) 的Tensor的通道划分成 \(g\) 组,得到shape为 \((*, C \frac g, g, H, W)\) 的Tensor,并沿着 \(C\)\(\frac{g}{}\)\(g\) 对应轴进行转置,将Tensor还原成原有的shape。

参数:
  • groups (int) - 划分通道的组数,必须大于0。在上述公式中表示为 \(g\)

输入:
  • x (Tensor) - Tensor的shape \((*, C_{in}, H_{in}, W_{in})\)

输出:

Tensor,数据类型和shape与 x 相同。

异常:
  • TypeError - groups 非正整数。

  • ValueError - groups 小于1。

  • ValueError - x 的维度小于3。

  • ValueError - x 的通道数不能被 groups 整除。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore as ms
>>> import numpy as np
>>> channel_shuffle = ms.nn.ChannelShuffle(2)
>>> x = ms.Tensor(np.arange(16).astype(np.int32).reshape(1, 4, 2, 2))
>>> print(x)
[[[[ 0  1]
   [ 2  3]]
  [[ 4  5]
   [ 6  7]]
  [[ 8  9]
   [10 11]]
  [[12 13]
   [14 15]]]]
>>> output = channel_shuffle(x)
>>> print(output)
[[[[ 0  1]
   [ 2  3]]
  [[ 8  9]
   [10 11]]
  [[ 4  5]
   [ 6  7]]
  [[12 13]
   [14 15]]]]