# 模型迁移 [](https://gitee.com/mindspore/docs/blob/r2.6.0/tutorials/source_zh_cn/model_migration/model_migration.md) 本章节主要对模型迁移场景所必须的数据集、模型和训练、推理流程等在MindSpore上构建方法做简单的介绍。同时展示了MindSpore和PyTorch在数据集包装、模型构建、训练流程代码上的差别。 ## 模型分析 在进行正式的代码迁移前,需要对即将进行迁移的代码做一些简单的分析,判断哪些代码可以直接复用,哪些代码必须迁移到MindSpore。 一般的,只有与硬件相关的代码部分,才必须要迁移到MindSpore,比如: - 模型输入相关,包含模型参数加载,数据集包装等; - 模型构建和执行的代码; - 模型输出相关,包含模型参数保存等。 像Numpy、OpenCV等CPU上计算的三方库,以及Configuration、Tokenizer等不需要昇腾、GPU处理的Python操作,可以直接复用原始代码。 ## 数据集包装 MindSpore提供了多种典型开源数据集的解析读取,如MNIST、CIFAR-10、CLUE、LJSpeech等,详情可参考[mindspore.dataset](https://www.mindspore.cn/docs/zh-CN/r2.6.0/api_python/mindspore.dataset.html)。 ### 自定义数据加载 GeneratorDataset 在迁移场景,最常用的数据加载方式是[GeneratorDataset](https://www.mindspore.cn/docs/zh-CN/r2.6.0/api_python/dataset/mindspore.dataset.GeneratorDataset.html#mindspore.dataset.GeneratorDataset),只需对Python迭代器做简单包装,就可以直接对接MindSpore模型进行训练、推理。 ```python import numpy as np from mindspore import dataset as ds num_parallel_workers = 2 # 多线程/进程数 world_size = 1 # 并行场景使用,通信group_size rank = 0 # 并行场景使用,通信rank_id class MyDataset: def __init__(self): self.data = np.random.sample((5, 2)) self.label = np.random.sample((5, 1)) def __getitem__(self, index): return self.data[index], self.label[index] def __len__(self): return len(self.data) dataset = ds.GeneratorDataset(source=MyDataset(), column_names=["data", "label"], num_parallel_workers=num_parallel_workers, shuffle=True, num_shards=1, shard_id=0) train_dataset = dataset.batch(batch_size=2, drop_remainder=True, num_parallel_workers=num_parallel_workers) ``` 一个典型的数据集构造如上:构造一个Python类,必须有\_\_getitem\_\_和\_\_len\_\_方法,分别表示每一步迭代取的数据和整个数据集遍历一次的大小,其中index表示每次取数据的索引,当shuffle=False时按顺序递增,当shuffle=True时随机打乱。 GeneratorDataset至少需要包含: - source:一个Python迭代器; - column_names:迭代器\_\_getitem\_\_方法每个输出的名字。 更多使用方法参考[GeneratorDataset](https://www.mindspore.cn/docs/zh-CN/r2.6.0/api_python/dataset/mindspore.dataset.GeneratorDataset.html#mindspore.dataset.GeneratorDataset)。 dataset.batch将数据集中连续batch_size条数据,组合为一个批数据,至少需要包含: - batch_size:指定每个批处理数据包含的数据条目。 更多使用方法参考[Dataset.batch](https://www.mindspore.cn/docs/zh-CN/r2.6.0/api_python/dataset/dataset_method/batch/mindspore.dataset.Dataset.batch.html)。 ### 与PyTorch数据集构建差别  MindSpore的GeneratorDataset与PyTorch的DataLoader的主要差别有: - MindSpore的GeneratorDataset必须传入column_names; - PyTorch的数据增强输入的对象是Tensor类型,MindSpore的数据增强输入的对象是numpy类型,且数据处理不能用MindSpore的mint、ops和nn算子; - PyTorch的batch操作是DataLoader的属性,MindSpore的batch操作是独立的方法。 详细可参考[与torch.utils.data.DataLoader的差异](https://www.mindspore.cn/docs/zh-CN/r2.6.0/note/api_mapping/pytorch_diff/DataLoader.html)。 ## 模型构建 ### 网络基本构成单元 Cell MindSpore的网络搭建主要使用Cell进行图的构造,用户需要定义一个类继承Cell这个基类,在init里声明需要使用的API及子模块,在construct里进行计算:
PyTorch | MindSpore |
```python import torch class Network(torch.nn.Module): def __init__(self, forward_net): super(Network, self).__init__() self.net = forward_net def forward(self, x): y = self.net(x) return torch.nn.functional.relu(y) inner_net = torch.nn.Conv2d(120, 240, kernel_size=4, bias=False) net = Network(inner_net) for i in net.parameters(): print(i) ``` |
```python from mindspore import mint, nn class Network(nn.Cell): def __init__(self, forward_net): super(Network, self).__init__() self.net = forward_net def construct(self, x): y = self.net(x) return mint.nn.functional.relu(y) inner_net = mint.nn.Conv2d(120, 240, kernel_size=4, bias=False) net = Network(inner_net) for i in net.get_parameters(): print(i) ``` |
PyTorch | MindSpore |
```python # 使用torch.save()把获取到的state_dict保存到pkl文件中 torch.save(pt_model.state_dict(), save_path) # 使用torch.load()加载保存的state_dict, # 然后使用load_state_dict将获取到的state_dict加载到模型中 state_dict = torch.load(save_path) pt_model.load_state_dict(state_dict) ``` |
```python # 模型权重保存: ms.save_checkpoint(ms_model, save_path) # 使用ms.load_checkpoint()加载保存的ckpt文件, # 然后使用load_state_dict将获取到的param_dict加载到模型中 param_dict = ms.load_checkpoint(save_path) ms_model.load_state_dict(param_dict) ``` |
PyTorch | MindSpore |
```python optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) scheduler = ExponentialLR(optimizer, gamma=0.9) optimizer.zero_grad() output = model(input) loss = loss_fn(output, target) loss.backward() optimizer.step() scheduler.step() ``` |
```python import mindspore from mindspore import nn lr = nn.exponential_decay_lr(0.01, decay_rate, total_step, step_per_epoch, decay_epoch) optimizer = nn.SGD(model.trainable_params(), learning_rate=lr, momentum=0.9) grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True) (loss, _), grads = grad_fn(data, label) # 在优化器里自动做学习率更新 optimizer(grads) ``` |
PyTorch的自动微分 | MindSpore的自动微分 |
```python # torch.autograd: # backward是累计的,更新完之后需清空optimizer import torch.nn as nn import torch.optim as optim # 实例化模型和优化器 model = PT_Model() optimizer = optim.SGD(model.parameters(), lr=0.01) # 定义损失函数:均方误差(MSE) loss_fn = nn.MSELoss() # 前向传播:计算模型输出 y_pred = model(x) # 计算损失:将预测值与真实标签计算损失 loss = loss_fn(y_pred, y_true) # 反向传播:计算梯度 loss.backward() # 优化器更新 optimizer.step() ``` |
```python # ms.grad: # 使用grad接口,输入正向图,输出反向图 import mindspore as ms from mindspore import nn # 实例化模型和优化器 model = MS_Model() optimizer = nn.SGD(model.trainable_params(), learning_rate=0.01) # 定义损失函数:均方误差(MSE) loss_fn = nn.MSELoss() def forward_fn(x, y_true): # 前向传播:计算模型输出 y_pred = model(x) # 计算损失:将预测值与真实标签计算损失 loss = loss_fn(y_pred, y_true) return loss, y_pred # 计算loss和梯度 grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True) (loss, _), grads = grad_fn(x, y_true) # 优化器更新 optimizer(grads) ``` |