{ "cells": [ { "cell_type": "markdown", "id": "2042996d", "metadata": {}, "source": [ "# 优化器迁移指南\n", "\n", "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.8/zh_cn/migration_guide/mindspore_optim.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_download_code.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.8/zh_cn/migration_guide/mindspore_optim.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.8/docs/mindspore/source_zh_cn/migration_guide/optim.ipynb)\n", "\n", "## 概述\n", "\n", "优化器在模型训练过程中,用于计算和更新网络参数,本文对比MindSpore和PyTorch的在这一部分的实现方式差异,分别从基本用法、基类入参设置及支持的方法、自定义优化器、API映射四部分展开。\n", "\n", "## 基本用法\n", "\n", "MindSpore:使用优化器时,通常需要预先定义网络、损失函数和优化器:" ] }, { "cell_type": "code", "execution_count": 1, "id": "636abd8e", "metadata": { "ExecuteTime": { "end_time": "2022-05-31T06:38:05.064201Z", "start_time": "2022-05-31T06:38:05.039762Z" } }, "outputs": [], "source": [ "import mindspore as ms\n", "from mindspore import nn, ops\n", "import numpy as np\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.conv = nn.Conv2d(3, 64, 3)\n", " self.bn = nn.BatchNorm2d(64)\n", " def construct(self, x):\n", " x = self.conv(x)\n", " x = self.bn(x)\n", " return x\n", "\n", "net = Net()\n", "loss = nn.MSELoss()\n", "optimizer = nn.SGD(params=net.trainable_params(), learning_rate=0.01)" ] }, { "cell_type": "markdown", "id": "d34feac3", "metadata": {}, "source": [ "在MindSpore中,定义好网络、损失函数、优化器后,一般在以下三种场景下使用:\n", "\n", "- MindSpore封装了`Model`高阶API来方便用户定义和训练网络,在定义`Model`时指定优化器;" ] }, { "cell_type": "code", "execution_count": 2, "id": "662c6282", "metadata": { "ExecuteTime": { "end_time": "2022-05-31T06:38:09.659966Z", "start_time": "2022-05-31T06:38:09.653776Z" } }, "outputs": [], "source": [ "model = ms.Model(net, loss_fn=loss, optimizer=optimizer, metrics={\"accuracy\"})" ] }, { "cell_type": "markdown", "id": "58dd28c7", "metadata": {}, "source": [ "- MindSpore提供了`TrainOneStepCell`接口,通过传入优化器和一个`WithLossCell`的实例,自定义训练网络;" ] }, { "cell_type": "code", "execution_count": 3, "id": "1d9d6e69", "metadata": { "ExecuteTime": { "end_time": "2022-05-31T06:38:12.716746Z", "start_time": "2022-05-31T06:38:12.302393Z" } }, "outputs": [], "source": [ "# 使用TrainOneStepCell自定义网络\n", "loss_net = nn.WithLossCell(net, loss) # 包含损失函数的Cell\n", "train_net = nn.TrainOneStepCell(loss_net, optimizer)\n", "train_dataset = [(ms.Tensor(np.random.rand(1, 3, 64, 32), ms.float32), ms.Tensor(np.random.rand(1, 64, 64, 32), ms.float32))]\n", "for i in range(5):\n", " for image, label in train_dataset:\n", " train_net.set_train()\n", " res = train_net(image, label) # 执行网络的单步训练" ] }, { "cell_type": "markdown", "id": "1475af7e", "metadata": {}, "source": [ "- 在PyNative模式下,实现单步执行优化器。" ] }, { "cell_type": "code", "execution_count": 5, "id": "a6e3383f", "metadata": { "ExecuteTime": { "end_time": "2022-05-31T06:38:33.058693Z", "start_time": "2022-05-31T06:38:32.895901Z" } }, "outputs": [], "source": [ "# pynative模式下,单步实现GradOperation求梯度,并执行优化器\n", "ms.set_context(mode=ms.PYNATIVE_MODE, device_target=\"CPU\")\n", "\n", "class GradWrap(nn.Cell):\n", " \"\"\" GradWrap definition \"\"\"\n", " def __init__(self, network):\n", " super(GradWrap, self).__init__(auto_prefix=False)\n", " self.network = network\n", " self.weights = ms.ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))\n", "\n", " def construct(self, x, label):\n", " weights = self.weights\n", " return ops.GradOperation(get_by_list=True)(self.network, weights)(x, label)\n", "\n", "loss_net = nn.WithLossCell(net, loss)\n", "train_network = GradWrap(loss_net)\n", "\n", "output = net(image)\n", "loss_output = loss(output, label)\n", "grads = train_network(image, label)\n", "success = optimizer(grads)" ] }, { "cell_type": "markdown", "id": "28376f46", "metadata": {}, "source": [ "PyTorch:PyTorch为`Tensor`建立了`grad`属性和`backward`方法,`tensor.grad`是通过`tensor.backward`方法(本质是`torch.autograd.backward`)计算的,且在计算中进行梯度值累加,因此一般在调用`tensor.backward`方法前,需要手动将`grad`属性清零。MindSpore没有为`Tensor`和`grad`建立直接联系,在使用时不需要手动清零。\n", "\n", "在下面的代码中,初始化了一个优化器实例,每次循环调用`zero_grad`清零梯度,`backward`更新梯度,`step`更新网络参数,返回损失值。\n", "\n", "```python\n", "import torch\n", "from torch import optim, nn\n", "import numpy as np\n", "\n", "class Net(nn.Module):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.conv = nn.Conv2d(3, 64, 3)\n", " self.bn = nn.BatchNorm2d(64)\n", " self.relu = nn.ReLU()\n", " def forward(self, x):\n", " x = self.conv(x)\n", " x = self.bn(x)\n", " x = self.relu(x)\n", " return x\n", "\n", "model = Net()\n", "optimizer = optim.SGD(model.parameters(), lr=0.01)\n", "loss_fn = nn.MSELoss()\n", "train_dataset = [(torch.tensor(np.random.rand(1, 3, 64, 32).astype(np.float32)), torch.tensor(np.random.rand(1, 64, 62, 30).astype(np.float32)))]\n", "\n", "for epoch in range(5):\n", " for image, label in train_dataset:\n", " optimizer.zero_grad()\n", " output = model(image)\n", " loss = loss_fn(output, label)\n", " loss.backward()\n", " optimizer.step()\n", "```" ] }, { "cell_type": "markdown", "id": "915a50d5", "metadata": {}, "source": [ "## 基类入参设置及支持的方法\n", "\n", "### 基类入参\n", "\n", "MindSpore:\n", "\n", "```python\n", "optimizer(learning_rate, parameters, weight_decay=0.0, loss_scale=1.0)\n", "```" ] }, { "cell_type": "markdown", "id": "6c8bf483", "metadata": {}, "source": [ "PyTorch:\n", "\n", "```python\n", "optimizer(params, defaults)\n", "```" ] }, { "cell_type": "markdown", "id": "bc4c48e3", "metadata": {}, "source": [ "#### 网络中需要被训练的参数\n", "\n", "MindSpore和PyTorch的优化器都需要传入网络中需要被训练的参数,且参数的设置同时都支持默认接口和用户自定义设置两种方式。\n", "\n", "- 默认接口:\n", "\n", " MindSpore的`parameter`包含了网络中所有的参数,通过`require_grad`属性来区分是否需要训练和优化。`trainable_params`方法返回一个`filter`的`list`,筛选了网络中`require_grad`属性为True的`parameter`。\n", "\n", " ```python\n", " from mindspore import nn\n", " optim_sgd = nn.SGD(net.trainable_params())\n", " ```\n", "\n", " PyTorch的`state`包含了网络中所有的参数,其中需要被优化的是`parameter`,不需要优化的是`buffer`(例如:BatchNorm中的`running_mean`和`running_var` )。`parameters`方法返回需要被优化参数的`generator`。\n", "\n", " ```python\n", " from torch import nn, optim\n", " optim_sgd = optim.SGD(params=model.parameters(), lr=0.01)\n", " ```\n", "\n", "- 用户自定义:\n", "\n", " MindSpore和PyTorch都支持用户自定义传入需要优化的参数,例如,对非卷积参数进行训练和优化。代码样例如下:\n", "\n", " ```python\n", " from mindspore import nn\n", "\n", " net = Net()\n", " all_params = net.get_parameters()\n", " non_conv_params = list(filter(lambda x: \"conv\" not in x.name, all_params))\n", " optim_sgd = nn.SGD(params=non_conv_params)\n", " ```\n", "\n", " ```python\n", " from torch import optim\n", "\n", " net = Net()\n", " all_params = model.named_parameters()\n", " target_params = []\n", " for name, params in all_params:\n", " if \"conv\" in name:\n", " target_params.append(params)\n", " optim_sgd = optim.SGD(params=target_params, lr=0.01)\n", " ```" ] }, { "cell_type": "markdown", "id": "263931e4", "metadata": {}, "source": [ "#### 学习率\n", "\n", "使用固定学习率时,用法相同,传入固定值即可;使用动态学习率时,MindSpore和PyTorch都支持动态学习率调整策略,实现方式略有不同。\n", "\n", "- MindSpore:动态学习率有两种实现方式,预生成列表`mindspore.nn.dynamic_lr`和计算图格式`mindspore.nn.learning_rate_schedule`,且动态学习率实例作为优化器的参数输入。以预生成学习率列表的`piecewise_constant_lr`为例:" ] }, { "cell_type": "code", "execution_count": 14, "id": "9b68fd1a", "metadata": { "ExecuteTime": { "end_time": "2022-05-31T06:42:37.519171Z", "start_time": "2022-05-31T06:42:37.514573Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01]\n" ] } ], "source": [ "from mindspore import nn\n", "\n", "milestone = [2, 5, 10]\n", "learning_rates = [0.1, 0.05, 0.01]\n", "lr = nn.dynamic_lr.piecewise_constant_lr(milestone, learning_rates)\n", "print(lr)" ] }, { "cell_type": "markdown", "id": "60c12361", "metadata": {}, "source": [ "- PyTorch:优化器作为`lr_scheduler`的输入,调用`step`方法对学习率进行更新。\n", "\n", " ```python\n", " from torch import optim\n", "\n", " model = Net()\n", " optimizer = optim.SGD(model.parameters(), 0.1)\n", " scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)\n", "\n", " for epoch in range(5):\n", " for input, target in train_dataset:\n", " optimizer.zero_grad()\n", " output = model(input)\n", " loss = loss_fn(output, target)\n", " loss.backward()\n", " optimizer.step()\n", " scheduler.step()\n", " print(scheduler.get_last_lr())\n", "\n", " # out:\n", " # [0.09000000000000001]\n", " # [0.08100000000000002]\n", " # [0.07290000000000002]\n", " # [0.06561000000000002]\n", " # [0.05904900000000002]\n", " ```\n", "\n", "调整策略映射表\n", "\n", "| mindspore.nn.dynamic_lr | mindspore.nn.learning_rate_schedule | torch.optim.lr_scheduler |\n", "|:-------|:------|:-------------------|\n", "| `piecewise_constant_lr`:分段不变 | / | `StepLR`: 每隔step_size个epoch,学习率乘gamma;`MultiStepLR`: epoch为milestones的时候学习率乘️gamma\n", "|`exponential_decay_lr`:指数衰减| `ExponentialDecayLR`:指数衰减 | `ExponentialLR`: 指数衰减,lr = lr * (学习率乘gamma^epoch)\n", "| `natural_exp_decay_lr`:自然指数衰减 | `NaturalExpDecayLR`:自然指数衰减 | /\n", "| `inverse_decay_lr`:反时间衰减 | `InverseDecayLR`:反时间衰减 | /\n", "| `cosine_decay_lr`:余弦衰减|`CosineDecayLR`:余弦衰减 | `CosineAnnealingLR`: 余弦衰减\n", "|`polynomial_decay_lr`:多项式衰减 | `PolynomialDecayLR`:多项式衰减 | /\n", "| /|/ | `CosineAnnealingWarmRestarts`:周期变化余弦衰减\n", "| /|/ | `CyclicLR/OneCycleLR`:三角循环\n", "| /|/ | `ReduceLROnPlateau`:自适应调整\n", "| /|/ | `LambdaLR`:传入Lambda函数,自定义调整\n", "| /|/ | `MultiplicativeLR`:乘上lr_lambda中设置的数值" ] }, { "cell_type": "markdown", "id": "ad7cbe20", "metadata": {}, "source": [ "#### weight decay\n", "\n", "用法相同。一般情况下,weight_decay取值范围为\\[0, 1\\),实现对需要优化的参数使用权重衰减的策略,以避免模型过拟合问题;weight_decay的默认值为0.0,此时不使用权重衰减策略。\n", "\n", "#### 参数分组\n", "\n", "MindSpore和PyTorch都支持参数分组且使用方法相似,在使用时都是给优化器传入一个字典的列表,每个字典对应一个参数组,其中key为参数名,value为对应的设置值。不同点是,MindSpore只支持对“lr”,“weight_decay”,“grad_centralizaiton”实现分组,pytoch支持对所有参数进行分组。此外,PyTorch还支持`add_param_group`方法,对参数组进行添加和管理。\n", "\n", "> MindSpore和PyTorch各自有部分优化器不支持参数分组,请参考具体优化器的实现。\n", "\n", "MindSpore参数分组用法请参考[超参分组](https://www.mindspore.cn/tutorials/zh-CN/r1.8/advanced/network/optim.html#超参分组);PyTorch参数分组用法参考下述样例:\n", "\n", "```python\n", "from torch import optim\n", "\n", "net = Net()\n", "all_params = net.parameters()\n", "conv_params = []\n", "non_conv_params = []\n", "# 根据自己的筛选规则 将所有网络参数进行分组\n", "for pname, p in model.named_parameters():\n", " if ('conv' in pname):\n", " conv_params += [p]\n", " else:\n", " non_conv_params += [p]\n", "\n", "print(len(conv_params), len(non_conv_params))\n", "# 构建不同学习参数的优化器\n", "optimizer = torch.optim.SGD([\n", " {'params': conv_params, 'lr': 0.02},\n", " {'params': non_conv_params, 'weight_decay': 0.5}],\n", " lr=0.01, momentum=0.9)\n", "\n", "# out: 2 2\n", "```" ] }, { "cell_type": "markdown", "id": "d7dc7502", "metadata": {}, "source": [ "#### 混合精度\n", "\n", "MindSpore中的混合精度场景下,如果使用`FixedLossScaleManager`进行溢出检测,且`drop_overflow_update`为False时,优化器需设置`loss_scale`的值,且`loss_scale`值与`FixedLossScaleManager`的相同,详细使用方法可以参考[优化器的混合精度配置](https://www.mindspore.cn/tutorials/zh-CN/r1.8/advanced/network/optim.html#配置优化器)。PyTorch的混合精度设置不作为优化器入参。\n", "\n", "### 基类支持的方法\n", "\n", "#### 获取LR\n", "\n", "`torch.optim.lr_scheduler.get_last_lr()`:根据参数组返回对应的最新学习率数值的列表。\n", "\n", "mindspore中没有直接可以按照组别获取对应学习率的功能,但提供了以下方法辅助使用:\n", "\n", "- `mindspore.nn.optimizer.get_lr()`:获取当前step的学习率,可以在自定义优化器时,在construct方法中使用。\n", "\n", "- `mindspore.nn.optimizer.get_lr_parameter(params)`:获取指定参数组的参数学习率列表,如果是固定学习率,返回一个标量Parameter的列表;如果是计算图格式的动态学习率,返回一个Cell的列表;如果是列表格式的动态学习率,返回shape为(n,)的Parameter的列表(其中n是动态学习率列表的长度)。\n", "\n", "#### 获取优化器的状态\n", "\n", "`torch.optimizer.param_groups`:获取优化器相关配置参数的状态,返回数据格式为字典的列表,key为参数名,value为参数值。以SGD为例,字典的key为key为'params'、 'lr'、'momentum'、'dampening'、'weight_decay'、 'nesterov'等。\n", "\n", "`torch.optimizer.state_dict`:获取optimizer的状态,返回一个key为“state”、“param_groups”,value为对应数值的字典。\n", "\n", "MindSpore暂无对应功能。\n", "\n", "## 自定义优化器\n", "\n", "MindSpore和PyTorch都支持用户基于python基本语法及相关算子自定义优化器。在PyTorch中,通过重写`__init__`和`step`方法,用户可以根据需求自定义优化器,具体用法可以参考[这篇教程](http://mcneela.github.io/machine_learning/2019/09/03/Writing-Your-Own-Optimizers-In-Pytorch.html)。MindSpore也支持类似用法,以Momentum为例,使用基础的小算子构建:" ] }, { "cell_type": "code", "execution_count": null, "id": "90ddf7b8", "metadata": {}, "outputs": [], "source": [ "from mindspore import ops, nn\n", "\n", "class MomentumOpt(nn.Optimizer):\n", " def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False):\n", " super(MomentumOpt, self).__init__(learning_rate, params, weight_decay, loss_scale)\n", " self.momentum = ms.Parameter(ms.Tensor(momentum, ms.float32), name=\"momentum\")\n", " self.moments = self.parameters.clone(prefix=\"moments\", init=\"zeros\")\n", " self.assign = ops.Assign()\n", " def construct(self, gradients):\n", " params = self.parameters\n", " moments = self.moments\n", " success = None\n", " for param, mom, grad in zip(params, moments, gradients):\n", " # 小算子表达\n", " update = self.momentum * param + mom + self.learning_rate * grad\n", " success = self.assign(param, update)\n", " return success" ] }, { "cell_type": "markdown", "id": "04f64abf", "metadata": {}, "source": [ "MindSpore的`ops`模块也提供了`ApplyMomentum`的高阶算子,使用方式可参考:" ] }, { "cell_type": "code", "execution_count": null, "id": "be86e9e4", "metadata": {}, "outputs": [], "source": [ "from mindspore import ops, nn\n", "\n", "class MomentumOpt(nn.Optimizer):\n", " def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False):\n", " super(MomentumOpt, self).__init__(learning_rate, params, weight_decay, loss_scale)\n", " self.moments = self.parameters.clone(prefix=\"moments\", init=\"zeros\")\n", " self.opt = ops.ApplyMomentum(use_nesterov=use_nesterov)\n", " def construct(self, gradients):\n", " params = self.parameters\n", " moments = self.moments\n", " success = None\n", " for param, mom, grad in zip(params, moments, gradients):\n", " # 大算子表达\n", " success = self.opt(param, mom, self.learning_rate, grad, self.momentum)\n", " return success" ] }, { "cell_type": "markdown", "id": "7e4ac030", "metadata": {}, "source": [ "## API映射\n", "\n", "MindSpore和PyTorch的API对应关系和差异可以参考[API映射](https://www.mindspore.cn/docs/zh-CN/r1.8/note/api_mapping/pytorch_api_mapping.html),其余暂时没有对应关系的接口目前情况如下:\n", "\n", "```python\n", "# PyTorch\n", "torch.optim.ASGD\n", "torch.optim.LBFGS\n", "```\n", "\n", "```python\n", "# mindspore\n", "mindspore.nn.ProximalAadagrad\n", "mindspore.nn.AdamOffload\n", "mindspore.nn.FTRL\n", "mindspore.nn.Lamb\n", "mindspore.nn.thor\n", "```" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }