# 网络搭建

[查看源文件](https://gitee.com/mindspore/docs/blob/r2.3.0rc2/docs/mindspore/source_zh_cn/migration_guide/model_development/model_and_cell.md)

## 基础逻辑

PyTorch和MindSpore的基础逻辑如下图所示:

可以看到,PyTorch和MindSpore在实现流程中一般都需要网络定义、正向计算、反向计算、梯度更新等步骤。

- 网络定义:在网络定义中,一般会定义出需要的前向网络,损失函数和优化器。在Net()中定义前向网络,PyTorch的网络继承nn.Module;类似地,MindSpore的网络继承nn.Cell。在MindSpore中,除了使用MindSpore中提供的损失函数和优化器外,用户还可以使用自定义的优化器。可参考[模型模块自定义](https://mindspore.cn/tutorials/zh-CN/r2.3.0rc2/advanced/modules.html)。可以使用functional/nn等接口拼接需要的前向网络、损失函数和优化器。
- 正向计算:运行实例化后的网络,可以得到logit,将logit和target作为输入计算loss。需要注意的是,如果正向计算的函数有多个输出,在反向计算时需要注意多个输出对于计算结果的影响。
- 反向计算:得到loss后,我们可以进行反向计算。在PyTorch中可使用loss.backward()计算梯度,在MindSpore中,先用mindspore.grad()定义出反向传播函数net_backward,再将输入传入net_backward中,即可计算梯度。如果正向计算的函数有多个输出,在反向计算时,可将has_aux设置为True,即可保证只有第一个输出参与求导,其它输出值将直接返回。对于反向计算中接口用法区别详见[自动微分对比](./gradient.md)。
- 梯度更新:将计算后的梯度更新到网络的Parameters中。在PyTorch中使用optim.step();在MindSpore中,将Parameter的梯度传入定义好的optim中,即可完成梯度更新。

## 网络基本构成单元 Cell

MindSpore的网络搭建主要使用[Cell](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell)进行图的构造,用户需要定义一个类继承 `Cell` 这个基类,在 `__init__` 里声明需要使用的API及子模块,在 `construct` 里进行计算, `Cell` 在 `GRAPH_MODE` (静态图模式)下将编译为一张计算图,在 `PYNATIVE_MODE` (动态图模式)下作为神经网络的基础模块。

PyTorch 和 MindSpore 基本的 `Cell` 搭建过程如下所示:

| PyTorch | MindSpore |
```python
import torch.nn as torch_nn


class MyCell_pt(torch_nn.Module):
    """Wrap a forward network and apply ReLU to its output."""

    def __init__(self, forward_net):
        """Register ``forward_net`` as a sub-module and create the ReLU layer."""
        super().__init__()
        self.net = forward_net
        self.relu = torch_nn.ReLU()

    def forward(self, x):
        """Run ``x`` through the wrapped network, then through ReLU."""
        y = self.net(x)
        return self.relu(y)


inner_net_pt = torch_nn.Conv2d(120, 240, kernel_size=4, bias=False)
pt_net = MyCell_pt(inner_net_pt)
# The convolution is built with bias=False, so only its weight is listed.
for param in pt_net.parameters():
    print(param.shape)
```
运行结果:
```text
torch.Size([240, 120, 4, 4])
```
|
```python
import mindspore.nn as nn
import mindspore.ops as ops


class MyCell(nn.Cell):
    """Wrap a forward network and apply ReLU to its output."""

    def __init__(self, forward_net):
        """Register ``forward_net`` as a sub-cell and create the ReLU operator."""
        # With auto_prefix=True the sub-cell's parameters are reported with the
        # attribute name as prefix (printed below as "net.weight").
        super().__init__(auto_prefix=True)
        self.net = forward_net
        self.relu = ops.ReLU()

    def construct(self, x):
        """Run ``x`` through the wrapped network, then through ReLU."""
        y = self.net(x)
        return self.relu(y)


inner_net = nn.Conv2d(120, 240, 4, has_bias=False)
my_net = MyCell(inner_net)
print(my_net.trainable_params())
```
运行结果:
```text
[Parameter (name=net.weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]
```
|
| PyTorch 设置模型数据类型 | MindSpore 设置模型数据类型 |
```python
import torch
import torch.nn as nn


class Network(nn.Module):
    """Two conv/BN/ReLU/pool stages, an adaptive pool, and a linear classifier."""

    def __init__(self):
        """Build the two convolutional stages, the pooling layer and the classifier."""
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 12, kernel_size=3, padding=1),
            nn.BatchNorm2d(12),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(12, 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # 4 channels * 5 * 5 = 100 features, matching the Linear input size.
        self.pool = nn.AdaptiveMaxPool2d((5, 5))
        self.fc = nn.Linear(100, 10)

    def forward(self, x):
        """Return the 10-way logits for an input batch ``x``."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.pool(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        out = self.fc(x)
        return out


net = Network()
net = net.to(torch.float32)
# Set the BatchNorm layers to float32 explicitly.
for _, module in net.named_modules():
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        module.to(torch.float32)
loss = nn.CrossEntropyLoss(reduction='mean')
loss = loss.to(torch.float32)
```
|
```python
import mindspore as ms
from mindspore import nn


# Define the model
class Network(nn.Cell):
    """Two conv/BN/ReLU/pool stages, an adaptive pool, and a dense classifier."""

    def __init__(self):
        """Build the two convolutional stages, the pooling layer and the classifier."""
        super().__init__()
        self.layer1 = nn.SequentialCell([
            nn.Conv2d(3, 12, kernel_size=3, pad_mode='pad', padding=1),
            nn.BatchNorm2d(12),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        ])
        self.layer2 = nn.SequentialCell([
            nn.Conv2d(12, 4, kernel_size=3, pad_mode='pad', padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        ])
        # 4 channels * 5 * 5 = 100 features, matching the Dense input size.
        self.pool = nn.AdaptiveMaxPool2d((5, 5))
        self.fc = nn.Dense(100, 10)

    def construct(self, x):
        """Return the 10-way logits for an input batch ``x``."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.pool(x)
        x = x.view((-1, 100))
        # Use the layer created in __init__; nn.Dense(x) would try to build a
        # new layer from a tensor instead of applying the classifier.
        out = self.fc(x)
        return out


net = Network()
# Tag every operation in net as float16; the framework inserts casts on the
# inputs at compile time.
net.to_float(ms.float16)
# Switch the BatchNorm cells back to float32.
for _, cell in net.cells_and_names():
    if isinstance(cell, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        cell.to_float(ms.float32)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean').to_float(ms.float32)
net_with_loss = nn.WithLossCell(net, loss_fn=loss)
```
|
| PyTorch | MindSpore |
```python
import torch.nn as nn

net = nn.Linear(2, 1)

# List every parameter of the layer.
for name, param in net.named_parameters():
    print("Parameter Name:", name)

# Freeze the bias so it no longer receives gradients.
for name, param in net.named_parameters():
    if "bias" in name:
        param.requires_grad = False

# List only the parameters that are still trainable.
for name, param in net.named_parameters():
    if param.requires_grad:
        print("Parameter Name:", name)
```
运行结果:
```text
Parameter Name: weight
Parameter Name: bias
Parameter Name: weight
```
|
```python
import mindspore.nn as nn

net = nn.Dense(2, 1, has_bias=True)
print(net.trainable_params())

# Freeze the bias so it no longer appears among the trainable parameters.
for param in net.trainable_params():
    param_name = param.name
    if "bias" in param_name:
        param.requires_grad = False

print(net.trainable_params())
```
运行结果:
```text
[Parameter (name=weight, shape=(1, 2), dtype=Float32, requires_grad=True), Parameter (name=bias, shape=(1,), dtype=Float32, requires_grad=True)]
[Parameter (name=weight, shape=(1, 2), dtype=Float32, requires_grad=True)]
```
|
| PyTorch | MindSpore |
```python
import torch
import torch.nn as nn

# Build a layer with known values so the round trip is easy to check.
linear_layer = nn.Linear(2, 1, bias=True)
linear_layer.weight.data.fill_(1.0)
linear_layer.bias.data.zero_()

print("Original linear layer parameters:")
print(linear_layer.weight)
print(linear_layer.bias)

# Save the parameters, then load them into a freshly created layer.
torch.save(linear_layer.state_dict(), 'linear_layer_params.pth')

new_linear_layer = nn.Linear(2, 1, bias=True)
new_linear_layer.load_state_dict(torch.load('linear_layer_params.pth'))

# The loaded parameters should match the original ones.
print("Loaded linear layer parameters:")
print(new_linear_layer.weight)
print(new_linear_layer.bias)
```
运行结果:
```text
Original linear layer parameters:
Parameter containing:
tensor([[1., 1.]], requires_grad=True)
Parameter containing:
tensor([0.], requires_grad=True)
Loaded linear layer parameters:
Parameter containing:
tensor([[1., 1.]], requires_grad=True)
Parameter containing:
tensor([0.], requires_grad=True)
```
|
```python
import mindspore as ms
import mindspore.ops as ops
import mindspore.nn as nn

net = nn.Dense(2, 1, has_bias=True)
for param in net.get_parameters():
    print(param.name, param.data.asnumpy())

# Save the parameters to a checkpoint and read them back as a dict.
ms.save_checkpoint(net, "dense.ckpt")
dense_params = ms.load_checkpoint("dense.ckpt")
print(dense_params)

# Build replacement parameters (all ones) under the same names.
new_params = {}
for param_name in dense_params:
    print(param_name, dense_params[param_name].data.asnumpy())
    new_params[param_name] = ms.Parameter(ops.ones_like(dense_params[param_name].data), name=param_name)

# Load the replacement parameters into the network and show the result.
ms.load_param_into_net(net, new_params)
for param in net.get_parameters():
    print(param.name, param.data.asnumpy())
```
运行结果:
```text
weight [[-0.0042482 -0.00427286]]
bias [0.]
{'weight': Parameter (name=weight, shape=(1, 2), dtype=Float32, requires_grad=True), 'bias': Parameter (name=bias, shape=(1,), dtype=Float32, requires_grad=True)}
weight [[-0.0042482 -0.00427286]]
bias [0.]
weight [[1. 1.]]
bias [1.]
```
|
| torch.nn.init | mindspore.common.initializer |
```python
import torch

x = torch.empty(2, 2)
torch.nn.init.uniform_(x)
```
|
```python
import mindspore
from mindspore.common.initializer import initializer, Uniform

x = initializer(Uniform(), [1, 2, 3], mindspore.float32)
```
|
| PyTorch | MindSpore |
```python
import torch.nn as nn


class MyModule(nn.Module):
    """Three convolutions followed by a small sequential block."""

    def __init__(self):
        """Register sub-modules by attribute assignment and by add_module."""
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Add a sub-module with add_module
        self.add_module('conv3', nn.Conv2d(64, 128, 3, 1))

        self.sequential_block = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, 1),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply the three convolutions and then the sequential block."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.sequential_block(x)
        return x


module = MyModule()

# Use named_modules to walk all sub-modules (direct and indirect)
for name, module_instance in module.named_modules():
    print(f"Module name: {name}, type: {type(module_instance)}")
```
运行结果:
```text
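Module name: , type: <class '__main__.MyModule'>
Module name: conv1, type: <class 'torch.nn.modules.conv.Conv2d'>
Module name: conv2, type: <class 'torch.nn.modules.conv.Conv2d'>
Module name: conv3, type: <class 'torch.nn.modules.conv.Conv2d'>
Module name: sequential_block, type: <class 'torch.nn.modules.container.Sequential'>
Module name: sequential_block.0, type: <class 'torch.nn.modules.activation.ReLU'>
Module name: sequential_block.1, type: <class 'torch.nn.modules.conv.Conv2d'>
Module name: sequential_block.2, type: <class 'torch.nn.modules.activation.ReLU'>
```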
|
```python
from mindspore import nn


class MyCell(nn.Cell):
    """Three convolutions followed by a small sequential block."""

    def __init__(self):
        """Register sub-cells by attribute assignment and by insert_child_to_cell."""
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Add a sub-cell with insert_child_to_cell
        self.insert_child_to_cell('conv3', nn.Conv2d(64, 128, 3, 1))

        self.sequential_block = nn.SequentialCell(
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, 1),
            nn.ReLU()
        )

    def construct(self, x):
        """Apply the three convolutions and then the sequential block."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.sequential_block(x)
        return x


module = MyCell()

# Use cells_and_names to walk all sub-cells (direct and indirect)
for name, cell_instance in module.cells_and_names():
    print(f"Cell name: {name}, type: {type(cell_instance)}")
```
运行结果:
```text
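Cell name: , type: <class '__main__.MyCell'>
Cell name: conv1, type: <class 'mindspore.nn.layer.conv.Conv2d'>
Cell name: conv2, type: <class 'mindspore.nn.layer.conv.Conv2d'>
Cell name: conv3, type: <class 'mindspore.nn.layer.conv.Conv2d'>
Cell name: sequential_block, type: <class 'mindspore.nn.layer.container.SequentialCell'>
Cell name: sequential_block.0, type: <class 'mindspore.nn.layer.activation.ReLU'>
Cell name: sequential_block.1, type: <class 'mindspore.nn.layer.conv.Conv2d'>
Cell name: sequential_block.2, type: <class 'mindspore.nn.layer.activation.ReLU'>
```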
|
| PyTorch | MindSpore |
```python
import torch

torch_net = torch.nn.Linear(3, 4)
torch_net.cpu()
```
|
```python
import mindspore

mindspore.set_context(device_target="CPU")
ms_net = mindspore.nn.Dense(3, 4)
```
|
| PyTorch | MindSpore |
```python
def box_select_torch(box, iou_score):
    """Return the entries of ``box`` whose IoU score is greater than 0.3.

    The result's size depends on how many scores pass the threshold, so the
    output shape is data-dependent.

    NOTE(review): presumably ``box`` is (N, 4) and ``iou_score`` is (N,);
    confirm against the caller.
    """
    mask = iou_score > 0.3
    return box[mask]
```
|
```python
import mindspore as ms
from mindspore import ops

ms.set_seed(1)


def box_select_ms(box, iou_score):
    """Return the values of ``box`` whose IoU score is greater than 0.3.

    NOTE(review): presumably ``box`` is (N, 4) and ``iou_score`` is (N,);
    confirm against the caller. Per the masked_select API documentation the
    result is a 1-D tensor, so the selected boxes come back flattened.
    """
    # Add a trailing axis so the per-box mask broadcasts across each box row.
    mask = (iou_score > 0.3).expand_dims(1)
    return ops.masked_select(box, mask)
```
|
| PyTorch | MindSpore |
```python
import torch
import torch.nn as torch_nn


class ClassLoss_pt(torch_nn.Module):
    """Cross-entropy loss averaged over the hardest 70% of positive samples."""

    def __init__(self):
        """Create the per-sample (unreduced) cross-entropy criterion."""
        super().__init__()
        self.con_loss = torch_nn.CrossEntropyLoss(reduction='none')

    def forward(self, pred, label):
        """Return the mean of the top-k per-sample losses.

        Samples with ``label > 0`` count as positive; ``torch.topk`` keeps the
        largest 70% of their losses (at least one).
        """
        mask = label > 0
        valid_label = label * mask
        # Clamp to at least 1 and hand topk a plain Python int for k.
        pos_num = int(torch.clamp(mask.sum() * 0.7, min=1))
        con = self.con_loss(pred, valid_label.long()) * mask
        loss, _ = torch.topk(con, k=pos_num)
        return loss.mean()
```
|
```python
import mindspore as ms
from mindspore import ops
from mindspore import nn as ms_nn


class ClassLoss_ms(ms_nn.Cell):
    """Cross-entropy loss averaged over the hardest 70% of positive samples."""

    def __init__(self):
        """Create the per-sample cross-entropy criterion and a descending sort."""
        super().__init__()
        self.con_loss = ms_nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="none")
        self.sort_descending = ops.Sort(descending=True)

    def construct(self, pred, label):
        """Return the mean of the per-sample losses selected by the top-k mask.

        MindSpore currently does not support a variable K for TopK, so this
        takes the K-th largest loss value and builds the top-k mask from it.
        """
        mask = label > 0
        valid_label = label * mask
        pos_num = ops.maximum(mask.sum() * 0.7, 1).astype(ms.int32)
        con = self.con_loss(pred, valid_label.astype(ms.int32)) * mask
        con_sort, unused_index = self.sort_descending(con)
        # The K-th largest loss is the threshold for the top-k mask.
        con_k = con_sort[pos_num - 1]
        con_mask = (con >= con_k).astype(con.dtype)
        loss = con * con_mask
        return loss.sum() / con_mask.sum()
```
|