Function Differences with torch.nn.Module.named_children()


torch.nn.Module.named_children

torch.nn.Module.named_children()

For more information, see torch.nn.Module.named_children.

mindspore.nn.Cell.name_cells

mindspore.nn.Cell.name_cells()

For more information, see mindspore.nn.Cell.name_cells.

Usage

PyTorch: obtains the names and modules of the immediate (outer-layer) child modules of the network; the return type is an iterator (generator).

MindSpore: obtains the names and cells of the immediate (outer-layer) child cells of the network; the return type is OrderedDict.

Code Example

from mindspore import nn

class ConvBN(nn.Cell):
    def __init__(self):
        super(ConvBN, self).__init__()
        self.conv = nn.Conv2d(3, 64, 3)
        self.bn = nn.BatchNorm2d(64)
    def construct(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

class MyNet(nn.Cell):
    def __init__(self):
        super(MyNet, self).__init__()
        self.build_block = nn.SequentialCell(ConvBN(), nn.ReLU())
    def construct(self, x):
        return self.build_block(x)

# The following implements mindspore.nn.Cell.name_cells() with MindSpore.
net = MyNet()
print(net.name_cells())
# Out:
OrderedDict([('build_block', SequentialCell<
  (0): ConvBN<
    (conv): Conv2d<input_channels=3, output_channels=64, kernel_size=(3, 3),stride=(1, 1),  pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=Falseweight_init=normal, bias_init=zeros, format=NCHW>
    (bn): BatchNorm2d<num_features=64, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=build_block.0.bn.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=build_block.0.bn.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=build_block.0.bn.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=build_block.0.bn.moving_variance, shape=(64,), dtype=Float32, requires_grad=False)>
    >
  (1): ReLU<>
  >)])
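
Because name_cells() returns an OrderedDict rather than an iterator of pairs, iterating over its .items() yields (name, cell) tuples comparable to what named_children() produces. A minimal sketch, reusing the net defined above:

# Iterate over the OrderedDict returned by name_cells(); .items() yields
# (name, cell) pairs, analogous to PyTorch's named_children().
for name, cell in net.name_cells().items():
    print("Name: ", name)
    print("Cell: ", cell)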
import torch.nn as nn

class ConvBN(nn.Module):
  def __init__(self):
    super(ConvBN, self).__init__()
    self.conv = nn.Conv2d(3, 64, 3)
    self.bn = nn.BatchNorm2d(64)
  def forward(self, x):
    x = self.conv(x)
    x = self.bn(x)
    return x

class MyNet(nn.Module):
  def __init__(self):
    super(MyNet, self).__init__()
    self.build_block = nn.Sequential(ConvBN(), nn.ReLU())
  def forward(self, x):
    return self.build_block(x)

# The following implements torch.nn.Module.named_children() with torch.
net = MyNet()
print(net.named_children(), "\n")
for name, child in net.named_children():
  print("Name: ", name)
  print("Child: ", child)
# Out:
<generator object Module.named_children at 0x7f6a6134abd0>

Name:  build_block
Child:  Sequential(
  (0): ConvBN(
    (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): ReLU()
)