Customizing bprop Function


Users can customize the backpropagation (bprop) function of an nn.Cell object. This lets them control how the gradients of that nn.Cell object are computed, which helps when locating gradient problems.

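To use a custom bprop function, add a user-defined method named bprop to the nn.Cell subclass. It takes the same inputs as construct, followed by out (the forward output) and dout (the gradient flowing back into the output), and returns one gradient per input. When the backward (reverse) graph is generated, the user-defined bprop function is used instead of the automatically derived gradient of construct.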
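The contract can be modelled in a few lines of plain Python. This is a toy model under stated assumptions, not MindSpore code: ToyCell, derived_grads, grads, Square and ClippedSquare are hypothetical names chosen for illustration, and gradient clipping is just one example of a custom rule. It shows that when a bprop method is present, it receives the inputs, the forward output and dout, and its return values replace the derived gradients.

```python
# Toy model of the bprop contract. ToyCell and its subclasses are hypothetical
# illustrations, not MindSpore classes or MindSpore internals.

class ToyCell:
    """Minimal stand-in for nn.Cell with an optional custom bprop."""

    def construct(self, *inputs):
        raise NotImplementedError

    def derived_grads(self, *inputs, dout):
        # Stands in for the gradients automatic differentiation would derive.
        raise NotImplementedError

    def grads(self, *inputs, dout=1.0):
        out = self.construct(*inputs)
        if hasattr(self, "bprop"):
            # Custom rule: receives the inputs, the forward output and dout.
            result = self.bprop(*inputs, out, dout)
        else:
            result = self.derived_grads(*inputs, dout=dout)
        if len(result) != len(inputs):
            raise ValueError("bprop must return one gradient per input")
        return tuple(result)


class Square(ToyCell):
    def construct(self, x):
        return x * x

    def derived_grads(self, x, dout):
        # d(x * x)/dx = 2 * x, scaled by the incoming gradient.
        return (2 * x * dout,)


class ClippedSquare(Square):
    # Hypothetical custom rule: same forward pass, gradient clipped to [-1, 1].
    def bprop(self, x, out, dout):
        g = 2 * x * dout
        return (max(-1.0, min(1.0, g)),)


print(Square().grads(3.0))         # (6.0,)
print(ClippedSquare().grads(3.0))  # (1.0,)
```

Both cells share the same construct; only the presence of bprop changes which gradient is returned.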

The sample code is as follows:

import mindspore as ms
from mindspore import nn, ops

ms.set_context(mode=ms.PYNATIVE_MODE)

class Net(nn.Cell):
    def construct(self, x, y):
        z = x * y
        z = z * y
        return z

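    def bprop(self, x, y, out, dout):
        # Custom backward rule: receives the construct inputs (x, y), the
        # forward output (out) and the gradient of the output (dout), and
        # returns one gradient per input. These values replace the
        # automatically derived gradients.
        x_dout = x + y
        y_dout = x * y
        return x_dout, y_dout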

# get_all=True returns the gradients with respect to all inputs of Net.
grad_all = ops.GradOperation(get_all=True)
output = grad_all(Net())(ms.Tensor(1, ms.float32), ms.Tensor(2, ms.float32))
print(output)
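Because Net.bprop returns (x + y, x * y), the example yields gradients of 3 and 2 for x and y, rather than the values that automatic differentiation would derive for z = x * y * y (y * y = 4 and 2 * x * y = 4 at x = 1, y = 2). The plain-Python sketch below, which does not use MindSpore, reproduces both sets of numbers; it assumes dout is 1, as when GradOperation is called without sens_param. The function names forward, derived_grad and custom_bprop are illustrative only.

```python
# Framework-free check of the numbers in the example (plain Python, no MindSpore).

def forward(x, y):
    # Same computation as Net.construct.
    z = x * y
    z = z * y
    return z

def derived_grad(x, y, dout=1.0):
    # Chain rule for z = x * y * y: dz/dx = y * y, dz/dy = 2 * x * y.
    return dout * y * y, dout * 2 * x * y

def custom_bprop(x, y, out, dout):
    # Same rule as Net.bprop: out and dout are not used.
    return x + y, x * y

x, y = 1.0, 2.0
out = forward(x, y)
print(derived_grad(x, y))            # (4.0, 4.0)
print(custom_bprop(x, y, out, 1.0))  # (3.0, 2.0)
```

Comparing the two results is a quick way to confirm that a custom bprop is actually being used, since it overrides the gradient derived from construct.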