mindspore.recompute

查看源文件
mindspore.recompute(block, *args, **kwargs)[源代码]

该函数用于减少显存的使用,当运行选定的模块时,不再保存其中的前向计算的产生的激活值,我们将在反向传播时,重新计算前向的激活值。

说明

  • 重计算函数只支持继承自Cell对象的模块,

  • 该函数当前只支持PyNative模式,在图模式下,可以尝试使用Cell.recompute()接口,

  • 当使用重计算函数时,传入的网络模块不能使用jit装饰器。

参数:
  • block (Cell) - 需要重计算的网络模块。

  • args (tuple) - 指需要重计算的网络模块的前向输入。

  • kwargs (dict) - 可选输入。

返回:

同block的返回类型相同。

异常:
  • TypeError - 如果 block 不是Cell对象。

  • AssertionError - 如果执行模式不是PyNative模式。

支持平台:

Ascend GPU CPU

样例:

>>> import numpy as np
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> from mindspore import Tensor, recompute
>>> class MyCell(nn.Cell):
...     def __init__(self):
...         super(MyCell, self).__init__(auto_prefix=False)
...         self.conv = nn.Conv2d(2, 2, 2, has_bias=False, weight_init='ones')
...         self.relu = ops.ReLU()
...
...     def construct(self, x):
...         y = recompute(self.conv, x)
...         return self.relu(y)
>>> inputs = Tensor(np.ones([2, 2, 2, 2]).astype(np.float32) * 2)
>>> my_net = MyCell()
>>> grad = ops.grad(my_net)(inputs)
>>> print(grad)
[[[[2. 4.]
   [4. 8.]]
  [[2. 4.]
   [4. 8.]]]
 [[[2. 4.]
   [4. 8.]]
  [[2. 4.]
   [4. 8.]]]]