{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Model基本使用\n", "\n", "`Ascend` `GPU` `CPU` `模型开发`\n", "\n", "[![在线运行](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9taW5kc3BvcmUtd2Vic2l0ZS5vYnMuY24tbm9ydGgtNC5teWh1YXdlaWNsb3VkLmNvbS9ub3RlYm9vay9tYXN0ZXIvcHJvZ3JhbW1pbmdfZ3VpZGUvemhfY24vbWluZHNwb3JlX21vZGVsX3VzZV9ndWlkZS5pcHluYg==&imageid=65f636a0-56cf-49df-b941-7d2a07ba8c8c) [![下载Notebook](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_notebook.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.6/programming_guide/zh_cn/mindspore_model_use_guide.ipynb) [![下载样例代码](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_download_code.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.6/programming_guide/zh_cn/mindspore_model_use_guide.py) [![查看源文件](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.6/docs/mindspore/programming_guide/source_zh_cn/model_use_guide.ipynb)\n", "\n", "## 概述\n", "\n", "[编程指南](https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.6/index.html)的网络构建部分讲述了如何定义前向网络、损失函数和优化器,并介绍了如何将这些结构封装成训练、评估网络并执行。在此基础上,本文档讲述如何使用高阶API`Model`进行模型训练和评估。\n", "\n", "通常情况下,定义训练和评估网络并直接运行,已经可以满足基本需求,但仍然建议通过`Model`来进行模型训练和评估。一方面,`Model`可以在一定程度上简化代码。例如:无需手动遍历数据集;在不需要自定义`TrainOneStepCell`的场景下,可以借助`Model`自动构建训练网络;可以使用`Model`的`eval`接口进行模型评估,直接输出评估结果,无需手动调用评价指标的`clear`、`update`、`eval`函数等。另一方面,`Model`提供了很多高阶功能,如数据下沉、混合精度等,在不借助`Model`的情况下,使用这些功能需要花费较多的时间仿照`Model`进行自定义。\n", "\n", "本文档首先对Model进行基本介绍,然后重点讲解如何使用`Model`进行模型训练、评估和推理。\n", "\n", "> 下述例子中,参数初始化使用了随机值,在具体执行中输出的结果可能与本地执行输出的结果不同;如果需要稳定输出固定的值,可以设置固定的随机种子,设置方法请参考[mindspore.set_seed()](https://www.mindspore.cn/docs/api/zh-CN/r1.6/api_python/mindspore/mindspore.set_seed.html)。\n", "\n", "## Model基本介绍\n", "\n", "`Model`是MindSpore提供的模型训练高阶API,可以进行模型训练、评估和推理。\n", "\n", "`Model`中包含入参:\n", "\n", "- network (Cell):一般情况下为前向网络,输入数据和标签,输出预测值。\n", "\n", "- loss_fn (Cell):所使用的损失函数。\n", "\n", "- optimizer (Cell):所使用的优化器。\n", "\n", "- metrics (set):进行模型评估时使用的评价指标,在不需要模型评估时使用默认值`None`。\n", "\n", "- eval_network (Cell):模型评估所使用的网络,在部分简单场景下不需要指定。\n", "\n", "- eval_indexes (List):用于指示评估网络输出的含义,配合`eval_network`使用,该参数的功能可通过`nn.Metric`的`set_indexes`代替,建议使用`set_indexes`。\n", "\n", "- amp_level (str):用于指定混合精度级别。设置为\"O0\"或\"O3\"级别时,默认不使用LossScale,设置为\"O2\"级别时,默认使用动态LossScale策略。通过kwargs可修改LossScale策略。\n", "\n", "- boost_level (str):用于指定Boost级别。\n", "\n", "- kwargs:用于指定混合精度、LossScale、Boost相关策略。\n", "\n", "`Model`提供了以下接口用于模型训练、评估和推理:\n", "\n", "- train:用于在训练集上进行模型训练。\n", "\n", "- eval:用于在验证集上进行模型评估。\n", "\n", "- predict:用于对输入的一组数据进行推理,输出预测结果。\n", "\n", "相关特性如下:\n", "\n", "- [混合精度](https://mindspore.cn/docs/programming_guide/zh-CN/r1.6/enable_mixed_precision.html)\n", "- [损失缩放](https://mindspore.cn/docs/programming_guide/zh-CN/r1.6/lossscale.html)\n", "- [分布式并行](https://mindspore.cn/docs/programming_guide/zh-CN/r1.6/distributed_training.html)\n", "- [自适应梯度求和](https://mindspore.cn/docs/programming_guide/zh-CN/r1.6/apply_adaptive_summation.html)\n", "- [降维训练](https://mindspore.cn/docs/programming_guide/zh-CN/r1.6/apply_dimention_reduce_training.html)\n", "\n", "## 模型训练、评估和推理\n", "\n", "对于简单场景的神经网络,可以在定义`Model`时指定前向网络`network`、损失函数`loss_fn`、优化器`optimizer`和评估指标`metrics`。此时,Model会使用`network`作为推理网络,并使用`nn.WithLossCell`和`nn.TrainOneStepCell`构建训练网络,使用`nn.WithEvalCell`构建评估网络。\n", "\n", "以[构建训练与评估网络](https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.6/train_and_eval.html)中使用的线性回归为例:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:43:30.392367Z", "start_time": "2022-01-04T06:43:28.436687Z" } }, "outputs": [], "source": [ "import mindspore.nn as nn\n", "from mindspore.common.initializer import Normal\n", "\n", "class LinearNet(nn.Cell):\n", " def __init__(self):\n", " super().__init__()\n", " self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))\n", "\n", " def construct(self, x):\n", " return self.fc(x)\n", "\n", "net = LinearNet()\n", "# 设定损失函数\n", "crit = nn.MSELoss()\n", "# 设定优化器\n", "opt = nn.Momentum(net.trainable_params(), learning_rate=0.005, momentum=0.9)\n", "# 设定评价指标\n", "metrics = {\"mae\"}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[构建训练与评估网络](https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.6/train_and_eval.html)中讲述了通过`nn.WithLossCell`、`nn.TrainOneStepCell`和`nn.WithEvalCell`构建训练和评估网络并直接运行方式。使用`Model`时则不需要手动构建训练和评估网络,用以下方式定义`Model`并调用`train`和`eval`接口能够达到相同的效果。\n", "\n", "创建训练集和验证集:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:43:30.523624Z", "start_time": "2022-01-04T06:43:30.394412Z" } }, "outputs": [], "source": [ "import numpy as np\n", "import mindspore.dataset as ds\n", "\n", "def get_data(num, w=2.0, b=3.0):\n", " for _ in range(num):\n", " x = np.random.uniform(-10.0, 10.0)\n", " noise = np.random.normal(0, 1)\n", " y = x * w + b + noise\n", " yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)\n", "\n", "def create_dataset(num_data, batch_size=16):\n", " dataset = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])\n", " dataset = dataset.batch(batch_size)\n", " return dataset\n", "\n", "# 创建数据集\n", "train_dataset = create_dataset(num_data=160)\n", "eval_dataset = create_dataset(num_data=80)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "定义Model并进行模型训练,通过`LossMonitor`回调函数查看在训练过程中的损失函数值:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:43:30.952190Z", "start_time": "2022-01-04T06:43:30.525149Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 1 step: 1, loss is 145.92702\n", "epoch: 1 step: 2, loss is 64.700424\n", "epoch: 1 step: 3, loss is 6.557139\n", "epoch: 1 step: 4, loss is 28.182562\n", "epoch: 1 step: 5, loss is 47.01909\n", "epoch: 1 step: 6, loss is 84.31042\n", "epoch: 1 step: 7, loss is 90.51094\n", "epoch: 1 step: 8, loss is 8.339919\n", "epoch: 1 step: 9, loss is 6.141535\n", "epoch: 1 step: 10, loss is 43.75537\n", "epoch: 2 step: 1, loss is 35.802357\n", "epoch: 2 step: 2, loss is 44.16835\n", "epoch: 2 step: 3, loss is 23.275116\n", "epoch: 2 step: 4, loss is 2.7919912\n", "epoch: 2 step: 5, loss is 10.78731\n", "epoch: 2 step: 6, loss is 44.453777\n", "epoch: 2 step: 7, loss is 32.597885\n", "epoch: 2 step: 8, loss is 16.546305\n", "epoch: 2 step: 9, loss is 0.89161384\n", "epoch: 2 step: 10, loss is 6.7151947\n" ] } ], "source": [ "from mindspore import Model\n", "from mindspore.train.callback import LossMonitor\n", "\n", "model = Model(network=net, loss_fn=crit, optimizer=opt, metrics=metrics)\n", "# 获取训练过程数据\n", "epochs = 2\n", "model.train(epochs, train_dataset, callbacks=[LossMonitor()], dataset_sink_mode=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "执行模型评估,获取评估结果:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:43:31.009605Z", "start_time": "2022-01-04T06:43:30.954322Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'mae': 4.40472822189331}\n" ] } ], "source": [ "eval_result = model.eval(eval_dataset)\n", "print(eval_result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "使用`predict`进行推理:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:43:31.042129Z", "start_time": "2022-01-04T06:43:31.011659Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 7.277156 ]\n", " [-4.7413435 ]\n", " [10.358773 ]\n", " [ 5.426777 ]\n", " [-4.0304213 ]\n", " [ 0.8506899 ]\n", " [-2.249575 ]\n", " [-8.091422 ]\n", " [-2.5145457 ]\n", " [-6.993862 ]\n", " [ 5.4423256 ]\n", " [ 4.464966 ]\n", " [ 5.0804405 ]\n", " [ 0.04378986]\n", " [ 2.9789488 ]\n", " [-0.4881246 ]]\n" ] } ], "source": [ "for d in eval_dataset.create_dict_iterator():\n", " data = d[\"data\"]\n", " break\n", "\n", "output = model.predict(data)\n", "print(output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "一般情况下需要对推理结果进行后处理才能得到比较直观的推理结果。\n", "\n", "与构建网络后直接运行不同,使用Model进行模型训练、推理和评估时,不需要`set_train`配置网络结构的执行模式。\n", "\n", "## 自定义场景的Model应用\n", "\n", "在[损失函数](https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.6/loss.html)和[构建训练与评估网络](https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.6/train_and_eval.html)中已经提到过,MindSpore提供的网络封装函数`nn.WithLossCell`、`nn.TrainOneStepCell`和`nn.WithEvalCell`并不适用于所有场景,实际场景中常常需要自定义网络的封装方式。这种情况下`Model`使用这些封装函数自动地进行封装显然是不合理的。接下来介绍这些场景下如何正确地使用`Model`。\n", "\n", "### 手动连接前向网络与损失函数\n", "\n", "在有多个数据或者多个标签的场景下,可以手动将前向网络和自定义的损失函数链接起来作为`Model`的`network`,`loss_fn`使用默认值`None`,此时`Model`内部便会直接使用`nn.TrainOneStepCell`将`network`与`optimizer`组成训练网络,而不会经过`nn.WithLossCell`。\n", "\n", "这里使用[损失函数](https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.6/loss.html)文档中的例子:\n", "\n", "1. 定义多标签数据集" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:43:31.051039Z", "start_time": "2022-01-04T06:43:31.043718Z" } }, "outputs": [], "source": [ "import numpy as np\n", "import mindspore.dataset as ds\n", "\n", "def get_multilabel_data(num, w=2.0, b=3.0):\n", " for _ in range(num):\n", " x = np.random.uniform(-10.0, 10.0)\n", " noise1 = np.random.normal(0, 1)\n", " noise2 = np.random.normal(-1, 1)\n", " y1 = x * w + b + noise1\n", " y2 = x * w + b + noise2\n", " yield np.array([x]).astype(np.float32), np.array([y1]).astype(np.float32), np.array([y2]).astype(np.float32)\n", "\n", "def create_multilabel_dataset(num_data, batch_size=16):\n", " dataset = ds.GeneratorDataset(list(get_multilabel_data(num_data)), column_names=['data', 'label1', 'label2'])\n", " dataset = dataset.batch(batch_size)\n", " return dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "2. 自定义多标签损失函数" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:43:31.066333Z", "start_time": "2022-01-04T06:43:31.052624Z" } }, "outputs": [], "source": [ "import mindspore.ops as ops\n", "from mindspore.nn import LossBase\n", "\n", "class L1LossForMultiLabel(LossBase):\n", " def __init__(self, reduction=\"mean\"):\n", " super(L1LossForMultiLabel, self).__init__(reduction)\n", " self.abs = ops.Abs()\n", "\n", " def construct(self, base, target1, target2):\n", " x1 = self.abs(base - target1)\n", " x2 = self.abs(base - target2)\n", " return self.get_loss(x1)/2 + self.get_loss(x2)/2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "3. 连接前向网络和损失函数,`net`使用上一节定义的`LinearNet`" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:43:31.095795Z", "start_time": "2022-01-04T06:43:31.069437Z" } }, "outputs": [], "source": [ "import mindspore.nn as nn\n", "\n", "class CustomWithLossCell(nn.Cell):\n", " def __init__(self, backbone, loss_fn):\n", " super(CustomWithLossCell, self).__init__(auto_prefix=False)\n", " self._backbone = backbone\n", " self._loss_fn = loss_fn\n", "\n", " def construct(self, data, label1, label2):\n", " output = self._backbone(data)\n", " return self._loss_fn(output, label1, label2)\n", "net = LinearNet()\n", "loss = L1LossForMultiLabel()\n", "loss_net = CustomWithLossCell(net, loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "4. 定义Model并进行模型训练" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:43:31.366951Z", "start_time": "2022-01-04T06:43:31.097817Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 1 step: 1, loss is 11.781048\n", "epoch: 1 step: 2, loss is 10.031279\n", "epoch: 1 step: 3, loss is 10.033403\n", "epoch: 1 step: 4, loss is 9.735289\n", "epoch: 1 step: 5, loss is 9.50007\n", "epoch: 1 step: 6, loss is 9.082216\n", "epoch: 1 step: 7, loss is 7.857524\n", "epoch: 1 step: 8, loss is 9.13205\n", "epoch: 1 step: 9, loss is 6.9698105\n", "epoch: 1 step: 10, loss is 3.9609227\n" ] } ], "source": [ "from mindspore.train.callback import LossMonitor\n", "from mindspore import Model\n", "\n", "opt = nn.Momentum(net.trainable_params(), learning_rate=0.005, momentum=0.9)\n", "model = Model(network=loss_net, optimizer=opt)\n", "multi_train_dataset = create_multilabel_dataset(num_data=160)\n", "model.train(epoch=1, train_dataset=multi_train_dataset, callbacks=[LossMonitor()], dataset_sink_mode=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "5. 模型评估\n", "\n", " `Model`默认使用`nn.WithEvalCell`构建评估网络,在不满足需求的情况下同样需要手动构建评估网络,多数据和多标签便是一个典型的场景。`Model`提供了`eval_network`用于设置自定义的评估网络。手动构建评估网络的方式如下:\n", "\n", " 自定义评估网络的封装方式:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:43:31.373188Z", "start_time": "2022-01-04T06:43:31.369046Z" } }, "outputs": [], "source": [ "import mindspore.nn as nn\n", "\n", "class CustomWithEvalCell(nn.Cell):\n", " def __init__(self, network):\n", " super(CustomWithEvalCell, self).__init__(auto_prefix=False)\n", " self.network = network\n", "\n", " def construct(self, data, label1, label2):\n", " output = self.network(data)\n", " return output, label1, label2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "手动构建评估网络:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:43:31.390466Z", "start_time": "2022-01-04T06:43:31.374781Z" }, "scrolled": true }, "outputs": [], "source": [ "eval_net = CustomWithEvalCell(net)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "使用Model进行模型评估:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:43:31.447161Z", "start_time": "2022-01-04T06:43:31.392535Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'mae1': 5.680728149414063, 'mae2': 5.3038993835449215}\n" ] } ], "source": [ "from mindspore.train.callback import LossMonitor\n", "from mindspore import Model\n", "\n", "mae1 = nn.MAE()\n", "mae2 = nn.MAE()\n", "mae1.set_indexes([0, 1])\n", "mae2.set_indexes([0, 2])\n", "\n", "model = Model(network=loss_net, optimizer=opt, eval_network=eval_net, metrics={\"mae1\": mae1, \"mae2\": mae2})\n", "multi_eval_dataset = create_multilabel_dataset(num_data=80)\n", "result = model.eval(multi_eval_dataset, dataset_sink_mode=False)\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- 在进行模型评估时,评估网络的输出会透传给评估指标的`update`函数,也就是说,`update`函数将接收到三个输入,分别为`logits`、`label1`和`label2`。`nn.MAE`仅允许在两个输入上计算评价指标,因此使用`set_indexes`指定`mae1`使用下标为0和1的输入,也就是`logits`和`label1`,计算评估结果;指定`mae2`使用下标为0和2的输入,也就是`logits`和`label2`,计算评估结果。\n", "- 在实际场景中,往往需要所有标签同时参与评估,这时候就需要自定义`Metric`,灵活使用评估网络的所有输出计算评估结果。`Metric`自定义方法详见:。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "6. 推理\n", "\n", " `Model`没有提供用于指定自定义推理网络的参数,此时可以直接运行前向网络获得推理结果。" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:43:31.481326Z", "start_time": "2022-01-04T06:43:31.449243Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[-2.8008308 ]\n", " [ 5.6060047 ]\n", " [-2.913618 ]\n", " [-6.4244847 ]\n", " [ 5.680258 ]\n", " [-6.449928 ]\n", " [ 5.8908176 ]\n", " [-3.6554213 ]\n", " [-4.1632366 ]\n", " [ 8.986961 ]\n", " [-8.218882 ]\n", " [ 7.7593 ]\n", " [-0.14718263]\n", " [ 7.31277 ]\n", " [ 1.1601944 ]\n", " [-9.319818 ]]\n" ] } ], "source": [ "for d in multi_eval_dataset.create_dict_iterator():\n", " data = d[\"data\"]\n", " break\n", "\n", "output = net(data)\n", "print(output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 自定义训练网络\n", "\n", "在自定义`TrainOneStepCell`时,需要手动构建训练网络作为`Model`的`network`,`loss_fn`和`optimizer`均使用默认值`None`,此时`Model`会使用`network`作为训练网络,而不会进行任何封装。\n", "\n", "自定义`TrainOneStepCell`的场景可参考[构建训练与评估网络](https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.6/train_and_eval.html),这里列举一个简单的例子,其中`loss_net`和`opt`为上一节定义的`CustomWithLossCell`和`Momentum`:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:43:31.684455Z", "start_time": "2022-01-04T06:43:31.482324Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 1 step: 1, loss is 5.2428284\n", "epoch: 1 step: 2, loss is 4.616859\n", "epoch: 1 step: 3, loss is 4.36201\n", "epoch: 1 step: 4, loss is 3.67413\n", "epoch: 1 step: 5, loss is 2.428535\n", "epoch: 1 step: 6, loss is 2.7347116\n", "epoch: 1 step: 7, loss is 2.4106321\n", "epoch: 1 step: 8, loss is 2.0769718\n", "epoch: 1 step: 9, loss is 2.5951004\n", "epoch: 1 step: 10, loss is 3.4012628\n" ] } ], "source": [ "from mindspore.nn import TrainOneStepCell as CustomTrainOneStepCell\n", "from mindspore import Model\n", "from mindspore.train.callback import LossMonitor\n", "\n", "# 手动构建训练网络\n", "train_net = CustomTrainOneStepCell(loss_net, opt)\n", "# 定义`Model`并执行训练\n", "model = Model(train_net)\n", "multi_train_ds = create_multilabel_dataset(num_data=160)\n", "model.train(epoch=1, train_dataset=multi_train_ds, callbacks=[LossMonitor()], dataset_sink_mode=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "此时`train_net`即为训练网络。自定义训练网络时,同样需要自定义评估网络,进行模型评估和推理的方式与上一节`手动连接前向网络与损失函数`相同。\n", "\n", "当自定义训练网络的标签和预测值均为单一值时,评价函数不需要特殊处理(自定义或使用`set_indexes`),其他场景仍然需要注意评价指标的正确使用方式。\n", "\n", "### 自定义网络的权重共享\n", "\n", "[构建训练与评估网络](https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.6/train_and_eval.html)中已经介绍过权重共享的机制,使用MindSpore构建不同网络结构时,只要这些网络结构是在同一个实例的基础上封装的,那这个实例中的所有权重便是共享的,一个网络结构中的权重发生变化,意味着其他网络结构中的权重同步发生了变化。\n", "\n", "在使用Model进行训练时,对于简单的场景,`Model`内部使用`nn.WithLossCell`、`nn.TrainOneStepCell`和`nn.WithEvalCell`在前向`network`实例的基础上构建训练和评估网络,`Model`本身确保了推理、训练、评估网络之间权重共享。但对于自定义使用Model的场景,用户需要注意前向网络仅实例化一次。如果构建训练网络和评估网络时分别实例化前向网络,那在使用`eval`进行模型评估时,便需要手动加载训练网络中的权重,否则模型评估使用的将是初始的权重值。" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 4 }