{ "cells": [ { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# 自定义训练\n", "\n", "[![](https://gitee.com/mindspore/docs/raw/r1.2/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.2/tutorials/source_zh_cn/intermediate/custom/train.ipynb) [![](https://gitee.com/mindspore/docs/raw/r1.2/resource/_static/logo_notebook.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.2/quick_start/mindspore_train.ipynb) [![](https://gitee.com/mindspore/docs/raw/r1.2/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9taW5kc3BvcmUtd2Vic2l0ZS5vYnMuY24tbm9ydGgtNC5teWh1YXdlaWNsb3VkLmNvbS9ub3RlYm9vay9tYXN0ZXIvdHV0b3JpYWxzL3poX2NuL21pbmRzcG9yZV90cmFpbi5pcHluYg==&imageid=65f636a0-56cf-49df-b941-7d2a07ba8c8c)\n", "\n", "MindSpore提供了`model.train`接口来进行模型训练。使用方式可以参考[初级教程-初学入门](https://www.mindspore.cn/tutorial/zh-CN/r1.2/quick_start.html)。此外,还可以使用`TrainOneStepCell`,该接口当前支持GPU、Ascend环境。\n", "\n", "作为高阶接口,`model.train`封装了`TrainOneStepCell`,可以直接利用设定好的网络、损失函数与优化器进行训练。用户也可以选择使用`TrainOneStepCell`实现更加灵活的训练,例如控制训练数据集、实现多输入多输出网络、或自定义训练过程。\n", "\n", "## TrainOneStepCell说明\n", "\n", "`TrainOneStepCell`中包含三种入参:\n", "\n", "- network (Cell):参与训练的网络,当前仅接受单输出网络。\n", "\n", "- optimizer (Cell):所使用的优化器。\n", "\n", "- sens (Number):反向传播的缩放比例。\n", "\n", "下面使用`TrainOneStepCell`替换`model.train`,实现简单的线性网络训练过程。\n", "\n", "## TrainOneStepCell使用示例\n", "\n", "### 创建模型并生成数据\n", "\n", "> 本小节详细解释说明可参考[初级教程-初学入门](https://www.mindspore.cn/tutorial/zh-CN/r1.2/quick_start.html)。\n", "\n", "定义网络LinearNet,内部有两层全连接层组成的网络, 包含5个入参和1个出参的神经网络。" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import numpy as np\n", "\n", "import mindspore.nn as nn\n", "import mindspore.ops as ops\n", "import mindspore.dataset as ds\n", "from mindspore import ParameterTuple\n", "\n", "class LinearNet(nn.Cell):\n", " def __init__(self):\n", " super().__init__()\n", " self.relu = nn.ReLU()\n", " self.dense1 = nn.Dense(5, 32)\n", " self.dense2 = nn.Dense(32, 1)\n", "\n", " def construct(self, x):\n", " x = self.dense1(x)\n", " x = self.relu(x)\n", " x = self.dense2(x)\n", " return x" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "产生输入数据。" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "np.random.seed(4)\n", "\n", "class DatasetGenerator:\n", " def __init__(self):\n", " self.data = np.random.randn(5, 5).astype(np.float32)\n", " self.label = np.random.randn(5, 1).astype(np.int32)\n", "\n", " def __getitem__(self, index):\n", " return self.data[index], self.label[index]\n", "\n", " def __len__(self):\n", " return len(self.data)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "数据处理。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# 对输入数据进行处理\n", "dataset_generator = DatasetGenerator()\n", "dataset = ds.GeneratorDataset(dataset_generator, [\"data\", \"label\"], shuffle=True)\n", "dataset = dataset.batch(32)\n", "\n", "# 实例化网络\n", "net = LinearNet()" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### 定义TrainOneStepCell\n", "\n", "在`TrainOneStepCell`中,可以实现对训练过程的个性化设定。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "class TrainOneStepCell(nn.Cell):\n", " def __init__(self, network, optimizer, sens=1.0):\n", " \"\"\"参数初始化\"\"\"\n", " super(TrainOneStepCell, self).__init__(auto_prefix=False)\n", " self.network = network\n", " # 使用tuple包装weight\n", " self.weights = ParameterTuple(network.trainable_params())\n", " self.optimizer = optimizer\n", " # 定义梯度函数\n", " self.grad = ops.GradOperation(get_by_list=True, sens_param=True)\n", " self.sens = sens\n", "\n", " def construct(self, data, label):\n", " \"\"\"构建训练过程\"\"\"\n", " weights = self.weights\n", " loss = self.network(data, label)\n", " # 为反向传播设定系数\n", " sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)\n", " grads = self.grad(self.network, weights)(data, label, sens)\n", " return loss, self.optimizer(grads)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### 网络训练\n", "\n", "在使用`TrainOneStepCell`时,需要利用`WithLossCell`接口引入损失函数,共同完成训练过程。下面利用之前设定好的参数训练LinearNet网络,并获取loss值。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.7998974\n", "0.79927444\n", "0.7986423\n", "0.7979911\n", "0.79732\n", "... ...\n", "0.042837422\n", "0.041227795\n", "0.039638687\n", "... ...\n", "9.276913e-06\n", "8.4145695e-06\n", "7.625091e-06\n", "6.904066e-06\n", "6.2513377e-06\n" ] } ], "source": [ "# 设定损失函数\n", "crit = nn.MSELoss()\n", "# 设定优化器\n", "opt = nn.Adam(params=net.trainable_params())\n", "# 引入损失函数\n", "net_with_criterion = nn.WithLossCell(net, crit)\n", "# 自定义网络训练\n", "train_net = TrainOneStepCell(net_with_criterion, opt)\n", "\n", "# 获取训练过程数据\n", "for d in dataset.create_dict_iterator():\n", " for i in range(300):\n", " train_net(d[\"data\"], d[\"label\"])\n", " print(net_with_criterion(d[\"data\"], d[\"label\"]))" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore-1.3.0", "language": "python", "name": "mindspore-1.3.0" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 4 }