自定义Cell的反向

下载Notebook下载样例代码查看源文件

用户可以自定义nn.Cell对象的反向传播(计算)函数,从而控制nn.Cell对象梯度计算的过程,定位梯度问题。

自定义bprop函数的使用方法是:在定义的nn.Cell对象里面增加一个用户自定义的bprop函数。训练的过程中会使用用户自定义的bprop函数来生成反向图。

示例代码:

[5]:
ms.set_context(mode=ms.PYNATIVE_MODE)

class Net(nn.Cell):
    def construct(self, x, y):
        z = x * y
        z = z * y
        return z

    def bprop(self, x, y, out, dout):
        x_dout = x + y
        y_dout = x * y
        return x_dout, y_dout

grad_all = ops.GradOperation(get_all=True)
output = grad_all(Net())(ms.Tensor(1, ms.float32), ms.Tensor(2, ms.float32))
print(output)
(Tensor(shape=[], dtype=Float32, value= 3), Tensor(shape=[], dtype=Float32, value= 2))