# Network Migration Debugging Example

[View Source on Gitee](https://gitee.com/mindspore/docs/blob/r2.3.0/docs/mindspore/source_en/migration_guide/sample_code.md)

The following uses the classic ResNet-50 network as an example to describe the network migration method in detail, based on the code.

## Model Analysis and Preparation

Assume that the MindSpore operating environment has been configured according to [Environment Preparation and Information Acquisition](https://www.mindspore.cn/docs/en/r2.3.0/migration_guide/enveriment_preparation.html), and that ResNet-50 has not yet been implemented in the models repository.

First, analyze the algorithm and the network structure.

The Residual Neural Network (ResNet) was proposed by Kaiming He et al. from Microsoft Research. Using residual units, they successfully trained a 152-layer network and won ILSVRC 2015. In a conventional convolutional or fully-connected network, information is gradually lost as it passes through the layers, and gradients may vanish or explode, so training a very deep network fails. ResNet alleviates these problems: by passing the input directly to the output, it preserves the information, and the network only needs to learn the difference between input and output, which simplifies the learning objective. This structure speeds up training and markedly improves model accuracy.

[Paper](https://arxiv.org/pdf/1512.03385.pdf): Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Deep Residual Learning for Image Recognition"

The [sample code of PyTorch ResNet-50 CIFAR-10](https://gitee.com/mindspore/docs/tree/r2.3.0/docs/mindspore/source_zh_cn/migration_guide/code/resnet_convert/resnet_pytorch) contains the PyTorch ResNet implementation, CIFAR-10 data processing, network training, and inference processes.

### Checklist

When reading the paper and referring to the implementation, analyze and fill in the following checklist:

|Trick|Record|
|----|----|
|Data augmentation| RandomCrop, RandomHorizontalFlip, Resize, Normalize|
|Learning rate decay policy| Fixed learning rate = 0.001|
|Optimization parameters| Adam optimizer, weight_decay = 1e-5|
|Training parameters| batch_size = 32, epochs = 90|
|Network structure optimization| Bottleneck |
|Training process optimization| None|

### Reproducing Reference Implementation

Download the PyTorch code and the CIFAR-10 dataset, then train the network:

```text
Train Epoch: 89 [0/1563 (0%)]   Loss: 0.010917
Train Epoch: 89 [100/1563 (6%)]   Loss: 0.013386
Train Epoch: 89 [200/1563 (13%)]   Loss: 0.078772
Train Epoch: 89 [300/1563 (19%)]   Loss: 0.031228
Train Epoch: 89 [400/1563 (26%)]   Loss: 0.073462
Train Epoch: 89 [500/1563 (32%)]   Loss: 0.098645
Train Epoch: 89 [600/1563 (38%)]   Loss: 0.112967
Train Epoch: 89 [700/1563 (45%)]   Loss: 0.137923
Train Epoch: 89 [800/1563 (51%)]   Loss: 0.143274
Train Epoch: 89 [900/1563 (58%)]   Loss: 0.088426
Train Epoch: 89 [1000/1563 (64%)]   Loss: 0.071185
Train Epoch: 89 [1100/1563 (70%)]   Loss: 0.094342
Train Epoch: 89 [1200/1563 (77%)]   Loss: 0.126669
Train Epoch: 89 [1300/1563 (83%)]   Loss: 0.245604
Train Epoch: 89 [1400/1563 (90%)]   Loss: 0.050761
Train Epoch: 89 [1500/1563 (96%)]   Loss: 0.080932

Test set: Average loss: -9.7052, Accuracy: 91%

Finished Training
```

You can download the training logs and saved parameter files from [resnet_pytorch_res](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/resnet_pytorch_res.zip).
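As a quick reminder of the structure that the migrated blocks need to reproduce, the residual computation described above can be sketched in a few lines. This is a conceptual PyTorch-style sketch (the layer sizes and the `channels` argument are illustrative, not taken from the reference code):

```python
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Conceptual residual block: the branch only has to learn F(x) = H(x) - x."""

    def __init__(self, channels):
        super().__init__()
        # Stacked conv/BN layers form the residual branch F(x).
        self.branch = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # The shortcut passes the input straight to the output, so information
        # is preserved and only the residual needs to be learned.
        return self.relu(self.branch(x) + x)
```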
### Analyzing API/Feature Missing

- API analysis

    | PyTorch API | MindSpore API | Different or Not |
    | ---------------------- | ------------------ | ------|
    | `nn.Conv2d` | `nn.Conv2d` | Yes. [Difference](https://www.mindspore.cn/docs/en/r2.3.0/note/api_mapping/pytorch_diff/Conv2d.html)|
    | `nn.BatchNorm2d` | `nn.BatchNorm2d` | Yes. [Difference](https://www.mindspore.cn/docs/en/r2.3.0/note/api_mapping/pytorch_diff/BatchNorm2d.html)|
    | `nn.ReLU` | `nn.ReLU` | No|
    | `nn.MaxPool2d` | `nn.MaxPool2d` | Yes. [Difference](https://www.mindspore.cn/docs/en/r2.3.0/note/api_mapping/pytorch_diff/MaxPool2d.html)|
    | `nn.AdaptiveAvgPool2d` | `nn.AdaptiveAvgPool2d` | No |
    | `nn.Linear` | `nn.Dense` | Yes. [Difference](https://www.mindspore.cn/docs/en/r2.3.0/note/api_mapping/pytorch_diff/Dense.html)|
    | `torch.flatten` | `nn.Flatten` | No|

    By using the [MindSpore Dev Toolkit](https://www.mindspore.cn/docs/en/r2.3.0/migration_guide/migrator_with_tools.html#network-migration-development) tool or checking the [PyTorch API Mapping](https://www.mindspore.cn/docs/en/r2.3.0/note/api_mapping/pytorch_api_mapping.html), we find that four APIs are different.

- Function analysis

    | PyTorch Function | MindSpore Function |
    | ------------------------- | ------------------------------------- |
    | `nn.init.kaiming_normal_` | `initializer(init='HeNormal')` |
    | `nn.init.constant_` | `initializer(init='Constant')` |
    | `nn.Sequential` | `nn.SequentialCell` |
    | `nn.Module` | `nn.Cell` |
    | `torch.distributed` | `set_auto_parallel_context` |
    | `torch.optim.SGD` | `nn.SGD` or `nn.Momentum` |

    (The interface design of MindSpore is different from that of PyTorch, so only the key functions are compared here.)

After the API and function analysis, we find that, compared with PyTorch, there are no missing APIs or functions in MindSpore.

## MindSpore Model Implementation

### Datasets

The CIFAR-10 dataset is organized as follows:

```text
└─dataset_path
    ├─cifar-10-batches-bin     # train dataset
    │    ├─ data_batch_1.bin
    │    ├─ data_batch_2.bin
    │    ├─ data_batch_3.bin
    │    ├─ data_batch_4.bin
    │    └─ data_batch_5.bin
    └─cifar-10-verify-bin      # evaluate dataset
         └─ test_batch.bin
```

The dataset processing is implemented in PyTorch and MindSpore as follows.
**PyTorch dataset processing:**
```python
import torch
import torchvision.transforms as trans
import torchvision

train_transform = trans.Compose([
    trans.RandomCrop(32, padding=4),
    trans.RandomHorizontalFlip(0.5),
    trans.Resize(224),
    trans.ToTensor(),
    trans.Normalize([0.4914, 0.4822, 0.4465],
                    [0.2023, 0.1994, 0.2010]),
])
test_transform = trans.Compose([
    trans.Resize(224),
    trans.RandomHorizontalFlip(0.5),
    trans.ToTensor(),
    trans.Normalize([0.4914, 0.4822, 0.4465],
                    [0.2023, 0.1994, 0.2010]),
])

# If necessary, set download=True in the datasets.CIFAR10 interface to download the dataset automatically.
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                          transform=train_transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        transform=test_transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False)
```
**MindSpore dataset processing:**
```python
import mindspore as ms
import mindspore.dataset as ds
from mindspore.dataset import vision
from mindspore.dataset.transforms import TypeCast


def create_cifar_dataset(dataset_path, do_train, batch_size=32,
                         image_size=(224, 224), rank_size=1, rank_id=0):
    dataset = ds.Cifar10Dataset(dataset_path, shuffle=do_train,
                                num_shards=rank_size, shard_id=rank_id)

    # define map operations
    trans = []
    if do_train:
        trans += [
            vision.RandomCrop((32, 32), (4, 4, 4, 4)),
            vision.RandomHorizontalFlip(prob=0.5)
        ]
    trans += [
        vision.Resize(image_size),
        vision.Rescale(1.0 / 255.0, 0.0),
        vision.Normalize([0.4914, 0.4822, 0.4465],
                         [0.2023, 0.1994, 0.2010]),
        vision.HWC2CHW()
    ]
    type_cast_op = TypeCast(ms.int32)
    data_set = dataset.map(operations=type_cast_op, input_columns="label")
    data_set = data_set.map(operations=trans, input_columns="image")
    # apply batch operations
    data_set = data_set.batch(batch_size, drop_remainder=do_train)
    return data_set
```
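Before moving on, it is worth checking that the migrated pipeline produces the expected shapes and types. A minimal sketch (the path `./data/cifar-10-batches-bin` is an assumption and should point to the extracted CIFAR-10 training binaries):

```python
# Sketch: fetch one batch from the MindSpore pipeline defined above and inspect it.
dataset = create_cifar_dataset("./data/cifar-10-batches-bin", do_train=True, batch_size=32)
for image, label in dataset.create_tuple_iterator():
    print(image.shape, image.dtype)   # expected: (32, 3, 224, 224) Float32
    print(label.shape, label.dtype)   # expected: (32,) Int32
    break
```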
### Network Model Implementation

Referring to the PyTorch implementation, the ResNet building blocks map to MindSpore as follows.

**PyTorch:**

```python
nn.Conv2d(
    in_planes,
    out_planes,
    kernel_size=3,
    stride=stride,
    padding=dilation,
    groups=groups,
    bias=False,
    dilation=dilation,
)
```

**MindSpore:**

```python
nn.Conv2d(
    in_planes,
    out_planes,
    kernel_size=3,
    pad_mode="pad",
    stride=stride,
    padding=dilation,
    group=groups,
    has_bias=False,
    dilation=dilation,
)
```
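The padding semantics are where a silent shape mismatch is most likely to creep in, so it can help to confirm that the migrated convolution keeps the same output shape as the PyTorch one. A minimal sketch of the MindSpore side, with illustrative channel and spatial sizes:

```python
import numpy as np
import mindspore as ms
from mindspore import nn

# Same 3x3 configuration as above, instantiated with example values.
conv = nn.Conv2d(64, 64, kernel_size=3, pad_mode="pad", stride=1,
                 padding=1, dilation=1, group=1, has_bias=False)
x = ms.Tensor(np.random.randn(1, 64, 56, 56), ms.float32)
print(conv(x).shape)  # (1, 64, 56, 56), matching PyTorch's Conv2d with padding=1
```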
| PyTorch | MindSpore |
| ------- | --------- |
| `nn.Module` | `nn.Cell` |
| `nn.ReLU(inplace=True)` | `nn.ReLU()` |
| Graph construction in `forward` | Graph construction in `construct` |
**PyTorch:**

```python
# PyTorch MaxPool2d with padding
maxpool = nn.MaxPool2d(kernel_size=3,
                       stride=2,
                       padding=1)
```

**MindSpore:**

```python
# MindSpore MaxPool2d with padding
maxpool = nn.SequentialCell([
    nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)),
           mode="CONSTANT"),
    nn.MaxPool2d(kernel_size=3, stride=2)])
```
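The Pad + MaxPool2d combination is meant to reproduce PyTorch's `padding=1` behavior; a quick shape check confirms it (a minimal sketch with an illustrative input size):

```python
import numpy as np
import mindspore as ms
from mindspore import nn

maxpool = nn.SequentialCell([
    nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)), mode="CONSTANT"),
    nn.MaxPool2d(kernel_size=3, stride=2)])
x = ms.Tensor(np.random.randn(1, 64, 112, 112), ms.float32)
print(maxpool(x).shape)  # (1, 64, 56, 56), same as PyTorch MaxPool2d(3, 2, padding=1)
```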
**PyTorch:**

```python
# PyTorch AdaptiveAvgPool2d
avgpool = nn.AdaptiveAvgPool2d((1, 1))
```

**MindSpore:**

```python
# When the PyTorch AdaptiveAvgPool2d output shape is set to 1,
# MindSpore ReduceMean performs the same function with higher speed.
mean = ops.ReduceMean(keep_dims=True)
```
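`ReduceMean` has to be told which axes to average over; when it replaces `AdaptiveAvgPool2d((1, 1))`, the reduction runs over the spatial axes. A minimal sketch with an illustrative feature-map shape:

```python
import numpy as np
import mindspore as ms
from mindspore import ops

mean = ops.ReduceMean(keep_dims=True)
x = ms.Tensor(np.random.randn(32, 2048, 7, 7), ms.float32)
out = mean(x, (2, 3))   # average over H and W, like AdaptiveAvgPool2d((1, 1))
print(out.shape)        # (32, 2048, 1, 1)
```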
**PyTorch:**

```python
# PyTorch fully-connected layer
fc = nn.Linear(512 * block.expansion, num_classes)
```

**MindSpore:**

```python
# MindSpore fully-connected layer
fc = nn.Dense(512 * block.expansion, num_classes)
```
PyTorch `nn.Sequential` corresponds to MindSpore `nn.SequentialCell`.
**PyTorch:**

```python
# PyTorch initialization
for m in self.modules():
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out",
                                nonlinearity="relu")
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros,
# and each residual block behaves like an identity.
# This improves the model by 0.2~0.3%.
# Reference: https://arxiv.org/abs/1706.02677
if zero_init_residual:
    for m in self.modules():
        is_bottleneck = isinstance(m, Bottleneck)
        is_basicblock = isinstance(m, BasicBlock)
        if is_bottleneck and m.bn3.weight is not None:
            nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
        elif is_basicblock and m.bn2.weight is not None:
            nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
```
**MindSpore:**

```python
# MindSpore initialization
import math
from mindspore.common import initializer

for _, cell in self.cells_and_names():
    if isinstance(cell, nn.Conv2d):
        cell.weight.set_data(initializer.initializer(
            initializer.HeNormal(negative_slope=0, mode='fan_out',
                                 nonlinearity='relu'),
            cell.weight.shape, cell.weight.dtype))
    elif isinstance(cell, (nn.BatchNorm2d, nn.GroupNorm)):
        cell.gamma.set_data(
            initializer.initializer("ones", cell.gamma.shape,
                                    cell.gamma.dtype))
        cell.beta.set_data(
            initializer.initializer("zeros", cell.beta.shape,
                                    cell.beta.dtype))
    elif isinstance(cell, nn.Dense):
        cell.weight.set_data(initializer.initializer(
            initializer.HeUniform(negative_slope=math.sqrt(5)),
            cell.weight.shape, cell.weight.dtype))
        cell.bias.set_data(
            initializer.initializer("zeros", cell.bias.shape,
                                    cell.bias.dtype))

if zero_init_residual:
    for _, cell in self.cells_and_names():
        is_bottleneck = isinstance(cell, Bottleneck)
        is_basicblock = isinstance(cell, BasicBlock)
        if is_bottleneck and cell.bn3.gamma is not None:
            cell.bn3.gamma.set_data(
                initializer.initializer("zeros", cell.bn3.gamma.shape,
                                        cell.bn3.gamma.dtype))
        elif is_basicblock and cell.bn2.gamma is not None:
            cell.bn2.gamma.set_data(
                initializer.initializer("zeros", cell.bn2.gamma.shape,
                                        cell.bn2.gamma.dtype))
```
### Loss Function

**PyTorch:**

```python
net_loss = torch.nn.CrossEntropyLoss()
```

**MindSpore:**

```python
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
```
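When the two frameworks later disagree on accuracy, it helps to first confirm that the migrated loss matches a hand-computed reference on identical inputs. A minimal sketch that uses NumPy as the reference (the logits and labels are made up):

```python
import numpy as np
import mindspore as ms
from mindspore import nn

logits = np.random.randn(4, 10).astype(np.float32)
labels = np.array([1, 3, 5, 7], dtype=np.int32)

# NumPy reference: mean softmax cross-entropy with sparse labels.
shifted = logits - logits.max(axis=1, keepdims=True)
log_prob = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
ref = -log_prob[np.arange(4), labels].mean()

loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
out = loss(ms.Tensor(logits), ms.Tensor(labels))
print(ref, float(out.asnumpy()))  # the two values should agree to float32 precision
```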
### Optimizer

**PyTorch:**

```python
net_opt = torch.optim.Adam(net.parameters(),
                           0.001,
                           weight_decay=1e-5)
```

**MindSpore:**

```python
optimizer = ms.nn.Adam(resnet.trainable_params(),
                       0.001,
                       weight_decay=1e-5)
```
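With the dataset, network, loss, and optimizer in place, the MindSpore side can be trained through the high-level `Model` API. The sketch below assumes `resnet`, `loss`, `optimizer`, and a `train_dataset` built with `create_cifar_dataset` from the previous steps; the epoch count follows the checklist:

```python
from mindspore.train import Model, LossMonitor

# Wrap the pieces defined above and run the training loop.
model = Model(resnet, loss_fn=loss, optimizer=optimizer, metrics={"acc"})
model.train(90, train_dataset, callbacks=[LossMonitor()], dataset_sink_mode=False)
```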
### Comparing Parameters

To check that the migrated network matches the reference one, print the parameter names and shapes of both models and compare them.

**PyTorch:**
```python
import torch

# Print the names and shapes of all parameters in the PyTorch checkpoint
# and return the parameter dictionary.
def pytorch_params(pth_file):
    par_dict = torch.load(pth_file, map_location='cpu')
    pt_params = {}
    for name in par_dict:
        parameter = par_dict[name]
        print(name, parameter.numpy().shape)
        pt_params[name] = parameter.numpy()
    return pt_params

pth_path = "resnet.pth"
pt_param = pytorch_params(pth_path)
print("=" * 20)
```
Result:
```text
conv1.weight (64, 3, 7, 7)
bn1.weight (64,)
bn1.bias (64,)
bn1.running_mean (64,)
bn1.running_var (64,)
bn1.num_batches_tracked ()
layer1.0.conv1.weight (64, 64, 1, 1)
```
**MindSpore:**
```python
from resnet_ms.src.resnet import resnet50 as ms_resnet50

# Print the names and shapes of all parameters in the MindSpore cell
# and return the parameter dictionary.
def mindspore_params(network):
    ms_params = {}
    for param in network.get_parameters():
        name = param.name
        value = param.data.asnumpy()
        print(name, value.shape)
        ms_params[name] = value
    return ms_params

ms_param = mindspore_params(ms_resnet50(num_classes=10))
print("=" * 20)
```
Result:
```text
conv1.weight (64, 3, 7, 7)
bn1.moving_mean (64,)
bn1.moving_variance (64,)
bn1.gamma (64,)
bn1.beta (64,)
layer1.0.conv1.weight (64, 64, 1, 1)
```
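The listings show how the parameter names differ (for example `bn1.weight` → `bn1.gamma`, `bn1.running_mean` → `bn1.moving_mean`). If the PyTorch weights need to be reused in MindSpore, a name-mapping pass over the two dictionaries can produce a MindSpore checkpoint. The sketch below is illustrative only: the renaming rules and the `"downsample.1"` pattern depend on the concrete network definition and have to be adapted:

```python
import mindspore as ms


def param_convert(pt_params, ms_params, ckpt_path="resnet_from_torch.ckpt"):
    """Map PyTorch parameter names/values onto MindSpore names and save a checkpoint.

    pt_params / ms_params are the dictionaries returned by pytorch_params()
    and mindspore_params() above.
    """
    bn_rules = {"running_mean": "moving_mean", "running_var": "moving_variance",
                "weight": "gamma", "bias": "beta"}
    new_params = []
    for pt_name, value in pt_params.items():
        ms_name = pt_name
        # BatchNorm parameters are the ones whose names change.
        if "bn" in pt_name or "downsample.1" in pt_name:
            for old, new in bn_rules.items():
                if pt_name.endswith(old):
                    ms_name = pt_name[: -len(old)] + new
                    break
        if ms_name in ms_params and ms_params[ms_name].shape == value.shape:
            new_params.append({"name": ms_name, "data": ms.Tensor(value)})
        else:
            print("skipped:", pt_name)  # e.g. num_batches_tracked has no counterpart
    ms.save_checkpoint(new_params, ckpt_path)
```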
### Inference and Validation

Load the trained parameters and validate the network on the CIFAR-10 test set.

**PyTorch:**
```python
import torch
import torchvision.transforms as trans
import torchvision
import torch.nn.functional as F
from resnet import resnet50


def test_epoch(model, device, data_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in data_loader:
            output = model(data.to(device))
            # sum up batch loss
            test_loss += F.nll_loss(output, target.to(device),
                                    reduction='sum').item()
            # get the index of the max log-probability
            pred = output.max(1)
            pred = pred[1]
            correct += pred.eq(target.to(device)).sum().item()
    test_loss /= len(data_loader.dataset)
    print('\nLoss: {:.4f}, Accuracy: {:.0f}%\n'.format(
        test_loss, 100. * correct / len(data_loader.dataset)))


use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
test_transform = trans.Compose([
    trans.Resize(224),
    trans.RandomHorizontalFlip(0.5),
    trans.ToTensor(),
    trans.Normalize([0.4914, 0.4822, 0.4465],
                    [0.2023, 0.1994, 0.2010]),
])
test_set = torchvision.datasets.CIFAR10(
    root='./data', train=False, transform=test_transform)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=1, shuffle=False)
# 2. define forward network
if use_cuda:
    net = resnet50(num_classes=10).cuda()
else:
    net = resnet50(num_classes=10)
net.load_state_dict(torch.load("./resnet.pth", map_location='cpu'))
test_epoch(net, device, test_loader)
```
**MindSpore:**
```python
import numpy as np
import mindspore as ms
from mindspore import nn
from src.dataset import create_dataset
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.config import config
from src.utils import init_env
from src.resnet import resnet50


def test_epoch(model, data_loader, loss_func):
    model.set_train(False)
    test_loss = 0
    correct = 0
    for data, target in data_loader:
        output = model(data)
        test_loss += float(loss_func(output, target).asnumpy())
        pred = np.argmax(output.asnumpy(), axis=1)
        correct += (pred == target.asnumpy()).sum()
    dataset_size = data_loader.get_dataset_size()
    test_loss /= dataset_size
    print('\nLoss: {:.4f}, Accuracy: {:.0f}%\n'.format(
        test_loss, 100. * correct / dataset_size))


@moxing_wrapper()
def test_net():
    init_env(config)
    eval_dataset = create_dataset(
        config.dataset_name,
        config.data_path,
        False, batch_size=1,
        image_size=(int(config.image_height),
                    int(config.image_width)))
    resnet = resnet50(num_classes=config.class_num)
    ms.load_checkpoint(config.checkpoint_path, resnet)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True,
                                            reduction='mean')
    test_epoch(resnet, eval_dataset, loss)


if __name__ == '__main__':
    test_net()
```
Execute the MindSpore inference script:

```shell
python test.py --data_path data/cifar10/ --checkpoint_path resnet.ckpt
```

PyTorch result:

```text
Loss: -9.7075, Accuracy: 91%
```

MindSpore result:

```text
run standalone!
Loss: 0.3240, Accuracy: 91%
```