{ "cells": [ { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2022-03-02T09:06:23.745016Z", "start_time": "2022-03-02T09:06:21.533915Z" } }, "source": [ "# 动态图模式应用\n", "\n", "[![在线运行](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9vYnMuZHVhbHN0YWNrLmNuLW5vcnRoLTQubXlodWF3ZWljbG91ZC5jb20vbWluZHNwb3JlLXdlYnNpdGUvbm90ZWJvb2svcjEuNy90dXRvcmlhbHMvemhfY24vYWR2YW5jZWQvcHluYXRpdmVfZ3JhcGgvbWluZHNwb3JlX3B5bmF0aXZlLmlweW5i&imageid=9d63f4d1-dc09-4873-b669-3483cea777c0) [![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.7/tutorials/zh_cn/advanced/pynative_graph/mindspore_pynative.ipynb) \n", "[![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/resource/_static/logo_download_code.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.7/tutorials/zh_cn/advanced/pynative_graph/mindspore_pynative.py) \n", "[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.7/tutorials/source_zh_cn/advanced/pynative_graph/pynative.ipynb)\n", "\n", "## 概述\n", "\n", "MindSpore的针对动态图和静态图两种模式在调试和运行方面做了不同的优化:\n", "\n", "- 动态图模式:也称PyNative模式,将神经网络中的各个算子逐一下发执行,方便用户编写和调试神经网络模型。\n", "- 静态图模式:也称Graph模式或者图模式,将神经网络模型编译成一整张图,然后下发执行。该模式利用图优化等技术提高运行性能,同时有助于规模部署和跨平台运行。\n", "\n", "在动态图模式下,MindSpore支持执行单算子、普通函数和网络,以及单独求梯度的操作,下面我们将通过示例代码详细介绍这几种操作的使用方法和注意事项。\n", "\n", "## 动态图模式下的操作\n", "\n", "首先,我们导入相关依赖,并设置运行模式为动态图模式:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2022-03-02T09:15:13.240496Z", "start_time": "2022-03-02T09:15:13.237903Z" } }, "outputs": [], "source": [ "import numpy as np\n", "import mindspore.ops as ops\n", "import mindspore.nn as nn\n", "from mindspore import Tensor, context\n", "\n", "context.set_context(mode=context.PYNATIVE_MODE)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 执行单算子\n", "\n", "下面为执行加法算子[mindspore.ops.Add](https://mindspore.cn/docs/zh-CN/r1.7/api_python/ops/mindspore.ops.Add.html#mindspore.ops.Add)的示例代码:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2022-03-02T09:08:29.337515Z", "start_time": "2022-03-02T09:08:29.322592Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x: [1. 2.] \n", "y: [3. 5.] \n", "z: [4. 7.]\n" ] } ], "source": [ "add = ops.Add()\n", "x = Tensor(np.array([1, 2]).astype(np.float32))\n", "y = Tensor(np.array([3, 5]).astype(np.float32))\n", "z = add(x, y)\n", "print(\"x:\", x.asnumpy(), \"\\ny:\", y.asnumpy(), \"\\nz:\", z.asnumpy())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 执行函数\n", "\n", "执行自定义函数`add_func`,示例代码如下:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2022-03-02T09:08:53.065585Z", "start_time": "2022-03-02T09:08:53.058016Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x: [1. 2.] \n", "y: [3. 5.] \n", "z: [5. 
9.]\n" ] } ], "source": [ "add = ops.Add()\n", "\n", "def add_func(x, y):\n", " z = add(x, y)\n", " z = add(z, x)\n", " return z\n", "\n", "x = Tensor(np.array([1, 2]).astype(np.float32))\n", "y = Tensor(np.array([3, 5]).astype(np.float32))\n", "z = add_func(x, y)\n", "print(\"x:\", x.asnumpy(), \"\\ny:\", y.asnumpy(), \"\\nz:\", z.asnumpy())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 执行网络\n", "\n", "执行自定义网络`Net`,在construct中定义网络结构,示例代码如下:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2022-03-02T09:09:16.498705Z", "start_time": "2022-03-02T09:09:16.490549Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x: [1. 2. 3.] \n", "y: [4. 5. 6.] \n", "z: [ 4. 10. 18.]\n" ] } ], "source": [ "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.mul = ops.Mul()\n", "\n", " def construct(self, x, y):\n", " return self.mul(x, y)\n", "\n", "net = Net()\n", "x = Tensor(np.array([1.0, 2.0, 3.0]).astype(np.float32))\n", "y = Tensor(np.array([4.0, 5.0, 6.0]).astype(np.float32))\n", "z = net(x, y)\n", "\n", "print(\"x:\", x.asnumpy(), \"\\ny:\", y.asnumpy(), \"\\nz:\", z.asnumpy())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 提升动态图模式性能\n", "\n", "为了提高动态图模式下的前向计算任务执行速度,MindSpore提供了`ms_function`装饰器,可以通过修饰Python函数或者Python类的成员函数使其被编译成计算图,通过图优化等技术提高运行速度。`ms_function`装饰器的使用方式及注意事项已在[动静结合](https://www.mindspore.cn/tutorials/zh-CN/r1.7/advanced/pynative_graph/combine.html#动静结合)章节中说明,此处不再赘述。\n", "\n", "## 动态图模式下同步执行\n", "\n", "在动态图模式下,为了提升性能,算子在device上使用了异步执行方式,因此在算子执行错误的时候,错误信息可能会在程序执行到最后才显示。针对这种情况,MindSpore增加了一个pynative_synchronize的设置来控制算子device上是否使用异步执行。\n", "\n", "动态图模式下算子默认为异步执行,可以通过设置context来控制是否异步执行。当算子执行失败时,可以方便地通过调用栈看到出错的代码位置。示例代码如下:\n", "\n", "```python\n", "import numpy as np\n", "import mindspore.context as context\n", "import mindspore.nn as nn\n", "from mindspore import Tensor\n", "from mindspore import dtype as mstype\n", "import mindspore.ops as ops\n", "\n", "# 通过设置pynative_synchronize来使算子同步执行\n", "context.set_context(mode=context.PYNATIVE_MODE, pynative_synchronize=True)\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.get_next = ops.GetNext([mstype.float32], [(1, 1)], 1, \"test\")\n", "\n", " def construct(self, x1,):\n", " x = self.get_next()\n", " x = x + x1\n", " return x\n", "\n", "context.set_context()\n", "x1 = np.random.randn(1, 1).astype(np.float32)\n", "net = Net()\n", "output = net(Tensor(x1))\n", "print(output.asnumpy())\n", "```\n", "\n", "输出:此时算子为同步执行,当算子执行错误时,可以看到完整的调用栈,找到出错的代码行。\n", "\n", "```text\n", "Traceback (most recent call last):\n", " File \"test.py\", line 24, in \n", " output = net(Tensor(x1))\n", " File \".../mindspore/nn/cell.py\", line 602, in __call__\n", " raise err\n", " File \".../mindspore/nn/cell.py\", line 599, in __call__\n", " output = self._run_construct(cast_inputs, kwargs)\n", " File \".../mindspore/nn/cell.py\", line 429, in _run_construct\n", " output = self.construct(*cast_inputs, **kwargs)\n", " File \"test.py\", line 17, in construct\n", " x = self.get_next()\n", " File \".../mindspore/ops/primitive.py\", line 294, in __call__\n", " return _run_op(self, self.name, args)\n", " File \".../mindspore/common/api.py\", line 90, in wrapper\n", " results = fn(*arg, **kwargs)\n", " File \".../mindspore/ops/primitive.py\", line 754, in _run_op\n", " output = real_run_op(obj, op_name, args)\n", "RuntimeError: 
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Synchronous Execution in Dynamic Graph Mode\n", "\n", "In dynamic graph mode, operators are executed asynchronously on the device to improve performance, so when an operator fails, the error message may not surface until the very end of program execution. For this situation, MindSpore provides the pynative_synchronize setting to control whether operators execute asynchronously on the device.\n", "\n", "Operators in dynamic graph mode execute asynchronously by default; asynchronous execution can be switched off through the context. With synchronous execution, when an operator fails, the offending line of code is easy to locate from the call stack. Example code:\n", "\n", "```python\n", "import numpy as np\n", "import mindspore.context as context\n", "import mindspore.nn as nn\n", "from mindspore import Tensor\n", "from mindspore import dtype as mstype\n", "import mindspore.ops as ops\n", "\n", "# Set pynative_synchronize=True so that operators execute synchronously\n", "context.set_context(mode=context.PYNATIVE_MODE, pynative_synchronize=True)\n", "\n", "class Net(nn.Cell):\n", "    def __init__(self):\n", "        super(Net, self).__init__()\n", "        self.get_next = ops.GetNext([mstype.float32], [(1, 1)], 1, \"test\")\n", "\n", "    def construct(self, x1):\n", "        x = self.get_next()\n", "        x = x + x1\n", "        return x\n", "\n", "x1 = np.random.randn(1, 1).astype(np.float32)\n", "net = Net()\n", "output = net(Tensor(x1))\n", "print(output.asnumpy())\n", "```\n", "\n", "Output: the operators now execute synchronously, so when one fails, the complete call stack shows the line of code that caused the error.\n", "\n", "```text\n", "Traceback (most recent call last):\n", "  File \"test.py\", line 24, in <module>\n", "    output = net(Tensor(x1))\n", "  File \".../mindspore/nn/cell.py\", line 602, in __call__\n", "    raise err\n", "  File \".../mindspore/nn/cell.py\", line 599, in __call__\n", "    output = self._run_construct(cast_inputs, kwargs)\n", "  File \".../mindspore/nn/cell.py\", line 429, in _run_construct\n", "    output = self.construct(*cast_inputs, **kwargs)\n", "  File \"test.py\", line 17, in construct\n", "    x = self.get_next()\n", "  File \".../mindspore/ops/primitive.py\", line 294, in __call__\n", "    return _run_op(self, self.name, args)\n", "  File \".../mindspore/common/api.py\", line 90, in wrapper\n", "    results = fn(*arg, **kwargs)\n", "  File \".../mindspore/ops/primitive.py\", line 754, in _run_op\n", "    output = real_run_op(obj, op_name, args)\n", "RuntimeError: mindspore/ccsrc/plugin/device/gpu/kernel/data/dataset_iterator_kernel.cc:139 Launch] For 'GetNext', gpu Queue(test) Open Failed: 2\n", "```\n", "\n", "## Hook Feature\n", "\n", "Debugging deep learning networks is work that every practitioner in the field has to face and invest significant effort in. Because a network hides the inputs, outputs, and backward gradients of its intermediate operators, exposing only the gradients of the network inputs (features and weights), it is impossible to observe precisely how the data of intermediate operators changes, which lowers debugging efficiency. To help users debug deep learning networks accurately and quickly, MindSpore provides a Hook feature in dynamic graph mode: **hooks can capture the inputs, outputs, and backward gradients of intermediate operators**.\n", "\n", "Dynamic graph mode currently provides four forms of the Hook feature: the HookBackward operator, and the register_forward_pre_hook, register_forward_hook, and register_backward_hook functions registered on Cell objects.\n", "\n", "### HookBackward Operator\n", "\n", "HookBackward implements the Hook feature in the form of an operator. The user instantiates a HookBackward operator and places it at the position in the network where the gradient needs to be captured. During the forward pass, the HookBackward operator outputs its input data unmodified; when the network propagates gradients backward, the hook function registered on it captures the gradient propagated back to that point. The hook function can perform custom operations on the gradient, such as printing it or returning a new gradient.\n", "\n", "Example code:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2022-03-02T09:13:38.660885Z", "start_time": "2022-03-02T09:13:38.645597Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "hook_fn print grad_out: (Tensor(shape=[], dtype=Float32, value= 2),)\n", "output: (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))\n" ] } ], "source": [ "import mindspore\n", "from mindspore import ops\n", "from mindspore import Tensor\n", "from mindspore import context\n", "from mindspore.ops import GradOperation\n", "\n", "context.set_context(mode=context.PYNATIVE_MODE)\n", "\n", "def hook_fn(grad_out):\n", "    \"\"\"Print the gradient\"\"\"\n", "    print(\"hook_fn print grad_out:\", grad_out)\n", "\n", "grad_all = GradOperation(get_all=True)\n", "hook = ops.HookBackward(hook_fn)\n", "\n", "def hook_test(x, y):\n", "    z = x * y\n", "    z = hook(z)\n", "    z = z * y\n", "    return z\n", "\n", "def net(x, y):\n", "    return grad_all(hook_test)(x, y)\n", "\n", "output = net(Tensor(1, mindspore.float32), Tensor(2, mindspore.float32))\n", "print(\"output:\", output)" ] },
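{ "cell_type": "markdown", "metadata": {}, "source": [ "The example above only prints the gradient, but a hook function may also return a new gradient to replace the captured one. The following is a minimal sketch under that assumption (the name `scaling_hook` and the factor 2 are illustrative; it reuses `ops`, `grad_all`, `Tensor`, and `mindspore` from the example above, and mirrors the tuple form of `grad_out` shown in the printed output):\n", "\n", "```python\n", "def scaling_hook(grad_out):\n", "    # Return a new gradient: each captured gradient scaled by 2\n", "    return tuple(g * 2 for g in grad_out)\n", "\n", "hook = ops.HookBackward(scaling_hook)\n", "\n", "def hook_test(x, y):\n", "    z = x * y\n", "    z = hook(z)  # forward pass-through; backward gradient is replaced\n", "    z = z * y\n", "    return z\n", "\n", "output = grad_all(hook_test)(Tensor(1, mindspore.float32), Tensor(2, mindspore.float32))\n", "print(output)\n", "```" ] },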
mindspore\n", "import mindspore.nn as nn\n", "import mindspore.ops as ops\n", "from mindspore import Tensor\n", "from mindspore import context\n", "from mindspore.ops import GradOperation\n", "\n", "context.set_context(mode=context.PYNATIVE_MODE)\n", "\n", "def forward_pre_hook_fn(cell_id, inputs):\n", " print(\"forward inputs: \", inputs)\n", " input_x = inputs[0]\n", " return input_x\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.relu = nn.ReLU()\n", " self.handle = self.relu.register_forward_pre_hook(forward_pre_hook_fn)\n", "\n", " def construct(self, x, y):\n", " x = x + y\n", " x = self.relu(x)\n", " return x\n", "\n", "grad = GradOperation(get_all=True)\n", "net = Net()\n", "\n", "x = Tensor(np.ones([1]).astype(np.float32))\n", "y = Tensor(np.ones([1]).astype(np.float32))\n", "\n", "output = net(x, y)\n", "print(output)\n", "gradient = grad(net)(x, y)\n", "print(gradient)\n", "net.handle.remove()\n", "gradient = grad(net)(x, y)\n", "print(gradient)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "用户如果在Hook函数中直接返回新创建的数据,而不是返回由原始输入数据经过计算后得到的数据,那么梯度的反向传播将会在该Cell对象上截止。\n", "\n", "示例代码:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n", "(Tensor(shape=[1], dtype=Float32, value= [ 0.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 0.00000000e+00]))\n" ] } ], "source": [ "import numpy as np\n", "import mindspore\n", "import mindspore.nn as nn\n", "from mindspore import Tensor\n", "from mindspore import context\n", "from mindspore.ops import GradOperation\n", "\n", "context.set_context(mode=context.PYNATIVE_MODE)\n", "\n", "def forward_pre_hook_fn(cell_id, inputs):\n", " print(\"forward inputs: \", inputs)\n", " return Tensor(np.ones([1]).astype(np.float32))\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.relu = nn.ReLU()\n", " self.handle = self.relu.register_forward_pre_hook(forward_pre_hook_fn)\n", "\n", " def construct(self, x, y):\n", " x = x + y\n", " x = self.relu(x)\n", " return x\n", "\n", "grad = GradOperation(get_all=True)\n", "net = Net()\n", "\n", "x = Tensor(np.ones([1]).astype(np.float32))\n", "y = Tensor(np.ones([1]).astype(np.float32))\n", "\n", "gradient = grad(net)(x, y)\n", "print(gradient)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_forward_pre_hook` 函数和 `handle` 对象的 `remove()` 函数。在动态图模式下,如果在Cell对象的 `construct` 函数中调用 `register_forward_pre_hook` 函数,那么Cell对象每次运行都将新注册一个Hook函数。\n", "\n", "更多关于Cell对象的 `register_forward_pre_hook` 功能的说明可以参考[API文档](https://mindspore.cn/docs/zh-CN/r1.7/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell.register_forward_pre_hook)。\n", "\n", "### Cell对象的register_forward_hook功能\n", "\n", "用户可以在Cell对象上使用`register_forward_hook`函数来注册一个自定义的Hook函数,用来捕获正向传入Cell对象的数据和Cell对象的输出数据。该功能在静态图模式下和在使用`ms_function`修饰的Cell对象上不起作用。`register_forward_hook`函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的`handle`对象。用户可以通过调用`handle`对象的`remove()`函数来删除与之对应的Hook函数。每一次调用`register_forward_hook`函数,都会返回一个不同的`handle`对象。Hook函数应该按照以下的方式进行定义。\n", "\n", "示例代码:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def forward_hook_fn(cell_id, inputs, outputs):\n", " print(\"forward inputs: \", inputs)\n", " print(\"forward outputs: \", outputs)" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "这里的`cell_id`是Cell对象的名称以及ID信息,`inputs`是正向传入到Cell对象的数据,`outputs`是Cell对象的正向输出数据。因此,用户可以使用`register_forward_hook`函数来捕获网络中某一个Cell对象的正向输入数据和输出数据。用户可以在Hook函数中自定义对输入、输出数据的操作,比如查看、打印数据,或者返回新的输出数据。如果在Hook函数中对Cell对象的原始输出数据进行计算操作后,再作为新的输出数据返回,这些新增的计算操作将会同时作用于梯度的反向传播。\n", "\n", "示例代码:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n", "forward outputs: [2.]\n", "(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]))\n", "(Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]))\n" ] } ], "source": [ "import numpy as np\n", "import mindspore\n", "import mindspore.nn as nn\n", "from mindspore import Tensor\n", "from mindspore import context\n", "from mindspore.ops import GradOperation\n", "\n", "context.set_context(mode=context.PYNATIVE_MODE)\n", "\n", "def forward_hook_fn(cell_id, inputs, outputs):\n", " print(\"forward inputs: \", inputs)\n", " print(\"forward outputs: \", outputs)\n", " outputs = outputs + outputs\n", " return outputs\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.relu = nn.ReLU()\n", " self.handle = self.relu.register_forward_hook(forward_hook_fn)\n", "\n", " def construct(self, x, y):\n", " x = x + y\n", " x = self.relu(x)\n", " return x\n", "\n", "grad = GradOperation(get_all=True)\n", "net = Net()\n", "\n", "x = Tensor(np.ones([1]).astype(np.float32))\n", "y = Tensor(np.ones([1]).astype(np.float32))\n", "\n", "gradient = grad(net)(x, y)\n", "print(gradient)\n", "net.handle.remove()\n", "gradient = grad(net)(x, y)\n", "print(gradient)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "用户如果在Hook函数中直接返回新创建的数据,而不是将原始的输出数据经过计算后,将得到的新输出数据返回,那么梯度的反向传播将会在该Cell对象上截止。该现象可以参考`register_forward_pre_hook`函数的用例说明。\n", "为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的`construct`函数中调用`register_forward_hook`函数和`handle`对象的`remove()`函数。在动态图模式下,如果在Cell对象的`construct`函数中调用`register_forward_hook`函数,那么Cell对象每次运行都将新注册一个Hook函数。\n", "\n", "更多关于Cell对象的`register_forward_hook`功能的说明可以参考[API文档](https://mindspore.cn/docs/zh-CN/r1.7/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell.register_forward_hook)。\n", "\n", "### Cell对象的register_backward_hook功能\n", "\n", "用户可以在Cell对象上使用`register_backward_hook`函数来注册一个自定义的Hook函数,用来捕获网络反向传播时与Cell对象相关联的梯度。该功能在图模式下或者在使用`ms_function`修饰的Cell对象上不起作用。`register_backward_hook`函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的`handle`对象。用户可以通过调用`handle`对象的`remove()`函数来删除与之对应的Hook函数。每一次调用`register_backward_hook`函数,都会返回一个不同的`handle`对象。\n", "\n", "与HookBackward算子所使用的自定义Hook函数有所不同,`register_backward_hook`使用的Hook函数的入参中,包含了表示Cell对象名称与id信息的`cell_id`、反向传入到Cell对象的梯度、以及Cell对象的反向输出的梯度。\n", "\n", "示例代码:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def backward_hook_function(cell_id, grad_input, grad_output):\n", " print(grad_input)\n", " print(grad_output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "这里的`cell_id`是Cell对象的名称以及ID信息,`grad_input`是网络反向传播时,传入到Cell对象的梯度,它对应于正向过程中下一个算子的反向输出梯度;`grad_output`是Cell对象反向输出的梯度。因此,用户可以使用`register_backward_hook`函数来捕获网络中某一个Cell对象的反向传入和反向输出梯度。用户可以在Hook函数中自定义对梯度的操作,比如查看、打印梯度,或者返回新的输出梯度。如果需要在Hook函数中返回新的输出梯度时,返回值必须是`tuple`的形式。\n", "\n", "示例代码:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": 
"2022-03-02T09:14:26.523389Z", "start_time": "2022-03-02T09:14:26.506784Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(Tensor(shape=[1, 2, 1, 1], dtype=Float32, value=\n", "[[[[ 1.00000000e+00]],\n", " [[ 1.00000000e+00]]]]),)\n", "(Tensor(shape=[1, 2, 1, 1], dtype=Float32, value=\n", "[[[[ 9.99994993e-01]],\n", " [[ 9.99994993e-01]]]]),)\n", "(Tensor(shape=[1, 1, 2, 2], dtype=Float32, value=\n", "[[[[ 1.99998999e+00, 1.99998999e+00],\n", " [ 1.99998999e+00, 1.99998999e+00]]]]),)\n", "-------------\n", " (Tensor(shape=[1, 1, 2, 2], dtype=Float32, value=\n", "[[[[ 1.99998999e+00, 1.99998999e+00],\n", " [ 1.99998999e+00, 1.99998999e+00]]]]),)\n" ] } ], "source": [ "import numpy as np\n", "import mindspore\n", "import mindspore.nn as nn\n", "from mindspore import Tensor\n", "from mindspore import context\n", "from mindspore.ops import GradOperation\n", "\n", "context.set_context(mode=context.PYNATIVE_MODE)\n", "\n", "def backward_hook_function(cell_id, grad_input, grad_output):\n", " print(grad_input)\n", " print(grad_output)\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.conv = nn.Conv2d(1, 2, kernel_size=2, stride=1, padding=0, weight_init=\"ones\", pad_mode=\"valid\")\n", " self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init=\"ones\")\n", " self.handle = self.bn.register_backward_hook(backward_hook_function)\n", " self.relu = nn.ReLU()\n", "\n", " def construct(self, x):\n", " x = self.conv(x)\n", " x = self.bn(x)\n", " x = self.relu(x)\n", " return x\n", "\n", "net = Net()\n", "grad_all = GradOperation(get_all=True)\n", "output = grad_all(net)(Tensor(np.ones([1, 1, 2, 2]).astype(np.float32)))\n", "print(output)\n", "net.handle.remove()\n", "output = grad_all(net)(Tensor(np.ones([1, 1, 2, 2]).astype(np.float32)))\n", "print(\"-------------\\n\", output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "当 `register_backward_hook` 函数和 `register_forward_pre_hook` 函数、 `register_forward_hook` 函数同时作用于同一Cell对象时,如果 `register_forward_pre_hook` 和 `register_forward_hook` 函数中有添加其他算子进行数据处理,这些新增算子会在Cell对象执行前或者执行后参与数据的正向计算,但是这些新增算子的反向梯度不在 `register_backward_hook` 函数的捕获范围内。 `register_backward_hook` 中注册的Hook函数仅捕获原始Cell对象的输入、输出梯度。\n", "\n", "示例代码:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n", "forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n", "forward outputs: [2.]\n", "grad input: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n", "grad output: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n", "(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]))\n" ] } ], "source": [ "import numpy as np\n", "import mindspore\n", "import mindspore.nn as nn\n", "from mindspore import Tensor\n", "from mindspore import context\n", "from mindspore.ops import GradOperation\n", "\n", "context.set_context(mode=context.PYNATIVE_MODE)\n", "\n", "def forward_pre_hook_fn(cell_id, inputs):\n", " print(\"forward inputs: \", inputs)\n", " input_x = inputs[0]\n", " return input_x\n", "\n", "def forward_hook_fn(cell_id, inputs, outputs):\n", " print(\"forward inputs: \", inputs)\n", " print(\"forward outputs: \", outputs)\n", " outputs = outputs + outputs\n", " return outputs\n", "\n", "def backward_hook_fn(cell_id, grad_input, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "Here `grad_input` is the gradient passed into `self.relu` during backward propagation, not the gradient of the `Add` operator added in the `forward_hook_fn` function; and `grad_output` is the gradient output by the backward pass of `self.relu`, not the backward output gradient of the `Add` operator added in the `forward_pre_hook_fn` function. The `register_forward_pre_hook` and `register_forward_hook` functions take effect before and after the Cell object executes, and do not affect the capture range of the backward hook on the Cell object.\n", "\n", "To avoid the script failing when it is switched to graph mode, it is not recommended to call the `register_backward_hook` function or the `remove()` function of the `handle` object inside a Cell object's `construct` function. In PyNative mode, if `register_backward_hook` is called inside a Cell object's `construct` function, a new hook function is registered every time the Cell object runs.\n", "\n", "More details on the `register_backward_hook` feature of Cell objects are available in the [API documentation](https://mindspore.cn/docs/zh-CN/r1.7/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell.register_backward_hook).\n", "\n", "## Custom bprop Feature\n", "\n", "The user can customize the backward propagation (computation) function of an nn.Cell object, and thereby control the gradient computation process of the nn.Cell object and localize gradient problems. A custom bprop function is used by adding a user-defined `bprop` function to the definition of the nn.Cell object; during training, the user-defined bprop function is used to generate the backward graph.\n", "\n", "Example code:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "ExecuteTime": { "end_time": "2022-03-02T09:14:55.896896Z", "start_time": "2022-03-02T09:14:55.881233Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(Tensor(shape=[], dtype=Float32, value= 3), Tensor(shape=[], dtype=Float32, value= 2))\n" ] } ], "source": [ "import mindspore\n", "import mindspore.nn as nn\n", "from mindspore import Tensor\n", "from mindspore import context\n", "from mindspore.ops import GradOperation\n", "\n", "context.set_context(mode=context.PYNATIVE_MODE)\n", "\n", "class Net(nn.Cell):\n", "    def construct(self, x, y):\n", "        z = x * y\n", "        z = z * y\n", "        return z\n", "\n", "    def bprop(self, x, y, out, dout):\n", "        x_dout = x + y\n", "        y_dout = x * y\n", "        return x_dout, y_dout\n", "\n", "grad_all = GradOperation(get_all=True)\n", "output = grad_all(Net())(Tensor(1, mindspore.float32), Tensor(2, mindspore.float32))\n", "print(output)" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 5 }