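{ "cells": [ { "cell_type": "markdown", "id": "170e84e3", "metadata": {}, "source": [ "# Customizing the Backward Propagation of a Cell\n", "\n", "[![Download Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_notebook.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.8/tutorials/experts/zh_cn/network/mindspore_custom_cell_reverse.ipynb) [![Download Sample Code](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_download_code.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.8/tutorials/experts/zh_cn/network/mindspore_custom_cell_reverse.py) [![View Source File](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.8/tutorials/experts/source_zh_cn/network/custom_cell_reverse.ipynb)\n", "\n", "You can customize the backward propagation (gradient computation) function of an `nn.Cell` object. This lets you control how the gradients of that `nn.Cell` are computed and helps you locate gradient problems.\n", "\n", "To use a custom backward function, add a user-defined `bprop` method to your `nn.Cell` subclass. Whenever gradients are computed, for example during training, the framework builds the backward graph from this `bprop` method instead of differentiating `construct` automatically. `bprop` receives the forward inputs (here `x` and `y`), the forward output `out`, and the gradient of the output `dout`. It returns a tuple with one gradient per forward input.\n", "\n", "In the example below, `construct` computes `z = x * y * y`. At `x = 1`, `y = 2`, its analytic gradients would be `(y * y, 2 * x * y) = (4, 4)`. Because `Net` defines `bprop`, `GradOperation` returns the values computed by `bprop` instead: `(x + y, x * y) = (3, 2)`.\n", "\n", "Sample code:" ] }, { "cell_type": "code", "execution_count": 5, "id": "05282818", "metadata": { "ExecuteTime": { "end_time": "2022-03-02T09:14:55.896896Z", "start_time": "2022-03-02T09:14:55.881233Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(Tensor(shape=[], dtype=Float32, value= 3), Tensor(shape=[], dtype=Float32, value= 2))\n" ] } ], "source": [ "import mindspore as ms\n", "from mindspore import nn, ops\n", "\n", "ms.set_context(mode=ms.PYNATIVE_MODE)\n", "\n", "class Net(nn.Cell):\n", "    def construct(self, x, y):\n", "        # Forward pass: z = x * y * y\n", "        z = x * y\n", "        z = z * y\n", "        return z\n", "\n", "    # Custom backward pass. Arguments: forward inputs x and y,\n", "    # forward output `out`, and gradient of the output `dout`.\n", "    # This illustrative bprop ignores `out` and `dout` and\n", "    # returns one gradient per forward input (x, y).\n", "    def bprop(self, x, y, out, dout):\n", "        x_dout = x + y\n", "        y_dout = x * y\n", "        return x_dout, y_dout\n", "\n", "# get_all=True returns the gradients with respect to all inputs\n", "grad_all = ops.GradOperation(get_all=True)\n", "output = grad_all(Net())(ms.Tensor(1, ms.float32), ms.Tensor(2, ms.float32))\n", "print(output)" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { 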
"codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 5 }