{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Download Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.6.0/tutorials/zh_cn/beginner/mindspore_save_load.ipynb) \n", "[![Download Sample Code](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.6.0/tutorials/zh_cn/beginner/mindspore_save_load.py) \n", "[![View Source File](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/tutorials/source_zh_cn/beginner/save_load.ipynb)\n", "\n", "[Introduction](https://www.mindspore.cn/tutorials/zh-CN/r2.6.0/beginner/introduction.html) || [Quick Start](https://www.mindspore.cn/tutorials/zh-CN/r2.6.0/beginner/quick_start.html) || [Tensor](https://www.mindspore.cn/tutorials/zh-CN/r2.6.0/beginner/tensor.html) || [Data Loading and Processing](https://www.mindspore.cn/tutorials/zh-CN/r2.6.0/beginner/dataset.html) || [Building a Network](https://www.mindspore.cn/tutorials/zh-CN/r2.6.0/beginner/model.html) || [Functional Automatic Differentiation](https://www.mindspore.cn/tutorials/zh-CN/r2.6.0/beginner/autograd.html) || [Model Training](https://www.mindspore.cn/tutorials/zh-CN/r2.6.0/beginner/train.html) || **Saving and Loading** || [Accelerating with Graph Mode](https://www.mindspore.cn/tutorials/zh-CN/r2.6.0/beginner/accelerate_with_static_graph.html) || [Automatic Mixed Precision](https://www.mindspore.cn/tutorials/zh-CN/r2.6.0/beginner/mixed_precision.html) ||" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# Saving and Loading\n", "\n", "The previous chapter described how to adjust hyperparameters and train a network model. During training, we usually want to save intermediate and final results for fine-tuning and for later model inference and deployment. This chapter describes how to save and load a model.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import mindspore\n", "from 
mindspore import nn\n", "from mindspore import Tensor" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Define a simple MLP classifier for 28x28 inputs.\n", "def network():\n", "    model = nn.SequentialCell(\n", "        nn.Flatten(),\n", "        nn.Dense(28*28, 512),\n", "        nn.ReLU(),\n", "        nn.Dense(512, 512),\n", "        nn.ReLU(),\n", "        nn.Dense(512, 10))\n", "    return model" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Saving and Loading Model Weights\n", "\n", "To save a model, call the `save_checkpoint` interface with the network and the path to save to:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Save the network's parameters to a checkpoint file.\n", "model = network()\n", "mindspore.save_checkpoint(model, \"model.ckpt\")" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "To load the model weights, first create an instance of the same model, then load the parameters with `load_checkpoint` and `load_param_into_net`." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Rebuild the network, then load the saved parameters into it.\n", "model = network()\n", "param_dict = mindspore.load_checkpoint(\"model.ckpt\")\n", "param_not_load, _ = mindspore.load_param_into_net(model, param_dict)\n", "print(param_not_load)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "> - `param_not_load` is the list of parameters that were not loaded. An empty list means that all parameters were loaded successfully.\n", "> - When MindX DL (the Ascend deep learning component) 6.0 or later is installed in the environment, MindIO-accelerated checkpointing is enabled by default. For details, see the [MindIO introduction](https://www.hiascend.com/document/detail/zh/mindx-dl/60rc1/mindio/mindioacp/mindioacp001.html). MindX DL can be downloaded [here](https://www.hiascend.com/developer/download/community/result?module=dl+cann)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Saving and Loading MindIR\n", "\n", "In addition to checkpoints, MindSpore provides, for both cloud-side training and device-side inference, a unified [IR (Intermediate 
Representation)](https://www.mindspore.cn/docs/zh-CN/r2.6.0/design/all_scenarios.html#中间表示mindir). The `export` interface saves a model directly as MindIR (currently only strict graph mode is supported)." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# MindIR export requires strict graph mode.\n", "mindspore.set_context(mode=mindspore.GRAPH_MODE, jit_syntax_level=mindspore.STRICT)\n", "model = network()\n", "# Dummy input that fixes the input shape for export.\n", "inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))\n", "mindspore.export(model, inputs, file_name=\"model\", file_format=\"MINDIR\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> MindIR saves both the checkpoint and the model structure, so an input Tensor must be defined to obtain the input shape." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An existing MindIR model can be loaded with the `load` interface and passed to `nn.GraphCell` to run inference.\n", "\n", "> `nn.GraphCell` supports graph mode only." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1, 10)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the exported graph and wrap it in a GraphCell for inference.\n", "graph = mindspore.load(\"model.mindir\")\n", "model = nn.GraphCell(graph)\n", "outputs = model(inputs)\n", "print(outputs.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Supported Syntax\n", "\n", "Not all Python syntax and data types can be exported to MindIR. MindIR export has a specific supported range, and exporting anything outside that range raises an error.\n", "\n", "First, MindIR export supports only **strict graph mode**. For the detailed supported range, see [Static Graph Syntax Support](https://www.mindspore.cn/tutorials/zh-CN/r2.6.0/compile/static_graph.html).\n", "\n", "Second, beyond the syntax restrictions of strict graph mode, MindIR places additional constraints on return types. For example, returning a `mindspore.dtype` is not supported, so exporting the following program to MindIR raises an error.\n", "\n", "```python\n", "import mindspore as ms\n", "from mindspore import nn, ops, Tensor\n", "\n", "class Model(nn.Cell):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.dtype = ops.DType()\n", "\n", "    def construct(self, x: Tensor) -> ms.dtype:\n", "        return self.dtype(x)\n", "```\n", "\n", "In addition, if a `Parameter` object is created outside an `nn.Cell`, MindIR cannot export that `Parameter`. This typically results from:\n", "\n", "- Creating a `Parameter` directly in the global scope of the script.\n", "- Creating a `Parameter` in a class that is not an `nn.Cell`.\n", "- Using 
the random number generation interfaces in the [mindspore.mint](https://www.mindspore.cn/docs/zh-CN/r2.6.0/api_python/mindspore.mint.html) package, such as `mint.randn` and `mint.randperm`, because these interfaces create a `Parameter` in the global scope.\n", "\n", "For example, exporting either of the following two programs raises an error.\n", "\n", "```python\n", "from mindspore import Tensor, Parameter, nn\n", "\n", "param = Parameter(Tensor([1, 2, 3, 4]))  # created outside nn.Cell\n", "\n", "class Model(nn.Cell):\n", "    def construct(self, x: Tensor) -> Tensor:\n", "        return x + param\n", "```\n", "\n", "```python\n", "from mindspore import Tensor, nn, mint\n", "\n", "class Model(nn.Cell):\n", "    def construct(self, n: int) -> Tensor:\n", "        return mint.randn(n)\n", "```" ] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" } }, "nbformat": 4, "nbformat_minor": 4 }