{ "cells": [ { "cell_type": "markdown", "id": "170e84e3", "metadata": {}, "source": [ "# 自定义Cell的反向传播函数\n", "\n", "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.9/resource/_static/logo_notebook.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.9/tutorials/experts/zh_cn/network/mindspore_custom_cell_reverse.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.9/resource/_static/logo_download_code.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.9/tutorials/experts/zh_cn/network/mindspore_custom_cell_reverse.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.9/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.9/tutorials/experts/source_zh_cn/network/custom_cell_reverse.ipynb)\n", "\n", "使用MindSpore构建神经网络时,需要继承 `nn.Cell` 类。构建网络的过程中,我们可能会遇到一些问题,例如:\n", "\n", "1. Cell中存在一些不可求导的或者是尚未定义反向传播规则的操作或算子;\n", "2. 替换Cell的某些正向计算过程时,需要自定义相应的反向传播函数。\n", "\n", "这时我们可以使用自定义Cell对象的反向传播函数的功能,形式为:\n", "\n", "```python\n", "def bprop(self, ..., out, dout):\n", " return ...\n", "```\n", "\n", "- 输入参数:与正向部分相同的输入参数再加上 `out` 和 `dout` , `out` 表示正向部分的计算结果, `dout` 表示回传到该 `nn.Cell` 对象的梯度。\n", "- 返回值:关于正向部分每个输入的梯度,所以返回值的数量需要与正向部分输入的数量相同。\n", "\n", "一个简单的完整示例如下:" ] }, { "cell_type": "code", "execution_count": 1, "id": "05282818", "metadata": { "ExecuteTime": { "end_time": "2022-03-02T09:14:55.896896Z", "start_time": "2022-03-02T09:14:55.881233Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(Tensor(shape=[2, 3], dtype=Float32, value=\n", "[[ 1.50000000e+00, 1.60000002e+00, 1.39999998e+00],\n", " [ 2.20000005e+00, 2.29999995e+00, 2.09999990e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=\n", "[[ 1.00999999e+00, 1.29999995e+00, 2.09999990e+00],\n", " [ 1.10000002e+00, 1.20000005e+00, 2.29999995e+00],\n", " [ 3.09999990e+00, 2.20000005e+00, 4.30000019e+00]]))\n" ] } ], "source": [ "import mindspore.nn as nn\n", "import mindspore as ms\n", "import mindspore.ops as ops\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.matmul = ops.MatMul()\n", "\n", " def construct(self, x, y):\n", " out = self.matmul(x, y)\n", " return out\n", "\n", " def bprop(self, x, y, out, dout):\n", " dx = x + 1\n", " dy = y + 1\n", " return dx, dy\n", "\n", "\n", "class GradNet(nn.Cell):\n", " def __init__(self, net):\n", " super(GradNet, self).__init__()\n", " self.net = net\n", " self.grad_op = ops.GradOperation(get_all=True)\n", "\n", " def construct(self, x, y):\n", " gradient_function = self.grad_op(self.net)\n", " return gradient_function(x, y)\n", "\n", "\n", "x = ms.Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=ms.float32)\n", "y = ms.Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=ms.float32)\n", "out = GradNet(Net())(x, y)\n", "print(out)" ] }, { "cell_type": "markdown", "id": "e87f016e", "metadata": {}, "source": [ "此示例通过定义Cell的 `bprop` 函数,对 `MatMul` 操作自定义了梯度计算过程,其中 `dx` 为对输入 `x` 的导数, `dy` 为对输入 `y` 的导数, `out` 为 `MatMul` 的计算结果, `dout` 为回传到 `Net` 的梯度。\n", "\n", "## 应用样例\n", "\n", "1. Cell中存在一些尚未定义反向传播规则的操作或算子。例如 `ReLU6` 算子尚未定义其二阶反向传播规则,这时我们可以通过自定义Cell的 `bprop` 函数去自定义 `ReLU6` 算子的二阶反向传播规则。代码如下:" ] }, { "cell_type": "code", "execution_count": 2, "id": "2671ebe0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[1. 1. 1.]\n", " [1. 1. 1.]]\n" ] } ], "source": [ "import mindspore.nn as nn\n", "from mindspore import Tensor\n", "from mindspore import dtype as mstype\n", "import mindspore.ops as ops\n", "\n", "\n", "class ReluNet(nn.Cell):\n", " def __init__(self):\n", " super(ReluNet, self).__init__()\n", " self.relu = ops.ReLU()\n", "\n", " def construct(self, x):\n", " return self.relu(x)\n", "\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.relu6 = ops.ReLU6()\n", " self.relu = ReluNet()\n", "\n", " def construct(self, x):\n", " return self.relu6(x)\n", "\n", " def bprop(self, x, out, dout):\n", " dx = self.relu(x)\n", " return (dx, )\n", "\n", "\n", "x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)\n", "net = Net()\n", "out = ops.grad(ops.grad(net))(x)\n", "print(out)" ] }, { "cell_type": "markdown", "id": "0d889eb3", "metadata": {}, "source": [ " 此代码通过自定义 `Net` 的 `bprop` 函数,定义了一阶反向传播规则,而二阶反向传播规则通过 `bprop` 中的 `self.relu` 的反向传播规则得到。\n", "\n", "2. 替换Cell的某些正向计算过程时,需要自定义相应的反向传播函数。例如SNN网络有如下代码:\n", "\n", " ```python\n", " class relusigmoid(nn.Cell):\n", " def __init__(self):\n", " super().__init__()\n", " self.sigmoid = ops.Sigmoid()\n", " self.greater = ops.Greater()\n", "\n", " def construct(self, x):\n", " spike = self.greater(x, 0)\n", " return spike.astype(mindspore.float32)\n", "\n", " def bprop(self, x, out, dout):\n", " sgax = self.sigmoid(x * 5.0)\n", " grad_x = dout * (1 - sgax) * sgax * 5.0\n", " return (grad_x,)\n", "\n", " class IFNode(nn.Cell):\n", " def __init__(self, v_threshold=1.0, fire=True, surrogate_function=relusigmoid()):\n", " super().__init__()\n", " self.v_threshold = v_threshold\n", " self.fire = fire\n", " self.surrogate_function = surrogate_function\n", "\n", " def construct(self, x, v):\n", " v = v + x\n", " if self.fire:\n", " spike = self.surrogate_function(v - self.v_threshold) * self.v_threshold\n", " v -= spike\n", " return spike, v\n", " return v, v\n", " ```\n", "\n", " 此代码自定义了一个新的激活函数relusigmoid,在子网 `IFNode` 里去替换原来的sigmoid激活函数,这时候就需要去自定义新的激活函数的反向传播规则。\n", "\n", "## 约束与限制\n", "\n", "- 当 `bprop` 函数的返回值数量为1时,也需要写成tuple的形式,即 `return (dx,)` 。\n", "- 图模式下, `bprop` 函数需要转换成图IR,所以需要遵循静态图语法,请参考[静态图语法支持](https://www.mindspore.cn/docs/zh-CN/r1.9/note/static_graph_syntax_support.html)。\n", "- 只支持返回关于正向部分输入的梯度,不支持返回关于 `Parameter` 的梯度。\n", "- 不支持在 `bprop` 中使用 `Parameter` 。" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 5 }