自定义Cell的反向传播函数

下载Notebook下载样例代码查看源文件

使用MindSpore构建神经网络时,需要继承 nn.Cell 类。构建网络的过程中,我们可能会遇到一些问题,例如:

  1. Cell中存在一些不可求导的或者是尚未定义反向传播规则的操作或算子;

  2. 替换Cell的某些正向计算过程时,需要自定义相应的反向传播函数。

这时我们可以使用自定义Cell对象的反向传播函数的功能,形式为:

def bprop(self, ..., out, dout):
    return ...
  • 输入参数:与正向部分相同的输入参数再加上 outdoutout 表示正向部分的计算结果, dout 表示回传到该 nn.Cell 对象的梯度。

  • 返回值:关于正向部分每个输入的梯度,所以返回值的数量需要与正向部分输入的数量相同。

一个简单的完整示例如下:

[1]:
import mindspore.nn as nn
import mindspore as ms
import mindspore.ops as ops

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.matmul = ops.MatMul()

    def construct(self, x, y):
        out = self.matmul(x, y)
        return out

    def bprop(self, x, y, out, dout):
        dx = x + 1
        dy = y + 1
        return dx, dy


class GradNet(nn.Cell):
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net
        self.grad_op = ops.GradOperation(get_all=True)

    def construct(self, x, y):
        gradient_function = self.grad_op(self.net)
        return gradient_function(x, y)


x = ms.Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=ms.float32)
y = ms.Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=ms.float32)
out = GradNet(Net())(x, y)
print(out)
(Tensor(shape=[2, 3], dtype=Float32, value=
[[ 1.50000000e+00,  1.60000002e+00,  1.39999998e+00],
 [ 2.20000005e+00,  2.29999995e+00,  2.09999990e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
[[ 1.00999999e+00,  1.29999995e+00,  2.09999990e+00],
 [ 1.10000002e+00,  1.20000005e+00,  2.29999995e+00],
 [ 3.09999990e+00,  2.20000005e+00,  4.30000019e+00]]))

此示例通过定义Cell的 bprop 函数,对 MatMul 操作自定义了梯度计算过程,其中 dx 为对输入 x 的导数, dy 为对输入 y 的导数, outMatMul 的计算结果, dout 为回传到 Net 的梯度。

应用样例

  1. Cell中存在一些尚未定义反向传播规则的操作或算子。例如 ReLU6 算子尚未定义其二阶反向传播规则,这时我们可以通过自定义Cell的 bprop 函数去自定义 ReLU6 算子的二阶反向传播规则。代码如下:

[2]:
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import dtype as mstype
import mindspore.ops as ops


class ReluNet(nn.Cell):
    def __init__(self):
        super(ReluNet, self).__init__()
        self.relu = ops.ReLU()

    def construct(self, x):
        return self.relu(x)


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.relu6 = ops.ReLU6()
        self.relu = ReluNet()

    def construct(self, x):
        return self.relu6(x)

    def bprop(self, x, out, dout):
        dx = self.relu(x)
        return (dx, )


x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
net = Net()
out = ops.grad(ops.grad(net))(x)
print(out)
[[1. 1. 1.]
 [1. 1. 1.]]

此代码通过自定义 Netbprop 函数,定义了一阶反向传播规则,而二阶反向传播规则通过 bprop 中的 self.relu 的反向传播规则得到。

  1. 替换Cell的某些正向计算过程时,需要自定义相应的反向传播函数。例如SNN网络有如下代码:

    class relusigmoid(nn.Cell):
        def __init__(self):
            super().__init__()
            self.sigmoid = ops.Sigmoid()
            self.greater = ops.Greater()
    
        def construct(self, x):
            spike = self.greater(x, 0)
            return spike.astype(mindspore.float32)
    
        def bprop(self, x, out, dout):
            sgax = self.sigmoid(x * 5.0)
            grad_x = dout * (1 - sgax) * sgax * 5.0
            return (grad_x,)
    
    class IFNode(nn.Cell):
        def __init__(self, v_threshold=1.0, fire=True, surrogate_function=relusigmoid()):
            super().__init__()
            self.v_threshold = v_threshold
            self.fire = fire
            self.surrogate_function = surrogate_function
    
        def construct(self, x, v):
            v = v + x
            if self.fire:
                spike = self.surrogate_function(v - self.v_threshold) * self.v_threshold
                v -= spike
                return spike, v
            return v, v
    

    此代码自定义了一个新的激活函数relusigmoid,在子网 IFNode 里去替换原来的sigmoid激活函数,这时候就需要去自定义新的激活函数的反向传播规则。

约束与限制

  • bprop 函数的返回值数量为1时,也需要写成tuple的形式,即 return (dx,)

  • 图模式下, bprop 函数需要转换成图IR,所以需要遵循静态图语法,请参考静态图语法支持

  • 只支持返回关于正向部分输入的梯度,不支持返回关于 Parameter 的梯度。

  • 不支持在 bprop 中使用 Parameter