mindspore.recompute

mindspore.recompute(block, *args, **kwargs)

This function reduces memory usage. When block runs, the intermediate activations computed in the forward pass are not stored; instead, they are recomputed during the backward pass, trading extra computation for lower memory.
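As a conceptual illustration of this memory-for-compute trade-off (often called activation checkpointing), this NumPy sketch contrasts a backward pass that reads a stored activation with one that first reruns the block's forward pass from its saved input. It is not MindSpore's implementation: TwoLayerBlock, grad_with_stored_activations, and grad_with_recompute are hypothetical names, and the loss is assumed to be the sum of the block's output.

```python
import numpy as np


class TwoLayerBlock:
    """Hypothetical toy block standing in for a Cell: out = tanh(x @ w1) @ w2."""

    def __init__(self, w1, w2):
        self.w1, self.w2 = w1, w2

    def forward(self, x):
        h = x @ self.w1  # intermediate activation
        a = np.tanh(h)   # intermediate activation that backward needs
        return a @ self.w2, a

    def backward(self, a, grad_out):
        """Return dL/dx from the activation a and the upstream gradient dL/d(out)."""
        grad_a = grad_out @ self.w2.T
        grad_h = grad_a * (1.0 - a ** 2)  # tanh'(h) = 1 - tanh(h)**2
        return grad_h @ self.w1.T


def grad_with_stored_activations(block, x):
    out, a = block.forward(x)  # a is kept in memory until backward runs
    return block.backward(a, np.ones_like(out))  # assumed loss L = sum(out)


def grad_with_recompute(block, x):
    # Forward pass: the intermediate a is discarded; only the input x is kept.
    out, _ = block.forward(x)
    # Backward pass: rerun the block's forward from x to rebuild a, then differentiate.
    _, a = block.forward(x)
    return block.backward(a, np.ones_like(out))
```

Both functions return the same gradient; the recompute variant pays for a second forward pass through the block in exchange for not keeping a alive between the two passes.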

Note

  • The recompute function only supports a block that inherits from Cell.

  • This function currently supports only PyNative mode. In graph mode, use the Cell.recompute interface instead.

  • When using the recompute function, the block object must not be decorated with @jit.

Parameters
  • block (Cell) – Block to be recomputed.

  • args (tuple) – Positional inputs passed to block for its forward pass.

  • kwargs (dict) – Optional keyword inputs for the recompute function.

Returns

Same as the return type of block.

Raises
Supported Platforms:

Ascend GPU CPU

Examples

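>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> from mindspore import Tensor, recompute
>>> ms.set_context(mode=ms.PYNATIVE_MODE)  # recompute() requires PyNative mode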
>>> class MyCell(nn.Cell):
...     def __init__(self):
...         super(MyCell, self).__init__(auto_prefix=False)
...         self.conv = nn.Conv2d(2, 2, 2, has_bias=False, weight_init='ones')
...         self.relu = ops.ReLU()
...
...     def construct(self, x):
...         y = recompute(self.conv, x)  # conv activations are recomputed during backward
...         return self.relu(y)
>>> inputs = Tensor(np.ones([2, 2, 2, 2]).astype(np.float32) * 2)
>>> my_net = MyCell()
>>> grad = ops.grad(my_net)(inputs)
>>> print(grad)
[[[[2. 4.]
   [4. 8.]]
  [[2. 4.]
   [4. 8.]]]
 [[[2. 4.]
   [4. 8.]]
  [[2. 4.]
   [4. 8.]]]]
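To see where the printed values come from, this NumPy sketch recomputes the same gradient by hand. It assumes that nn.Conv2d's default pad_mode='same' with a 2x2 kernel and stride 1 pads one row at the bottom and one column at the right (the asymmetric printed result is consistent with this), and that ops.grad on a non-scalar output differentiates the sum of the outputs. The helper names same_pad, conv2d_same, and grad_sum_relu_conv are hypothetical, and this is not how MindSpore computes gradients.

```python
import numpy as np


def same_pad(k):
    # Assumed 'same' padding for stride 1: any odd leftover goes to the bottom/right.
    total = k - 1
    before = total // 2
    return before, total - before


def conv2d_same(x, w):
    """Stride-1 'same' cross-correlation without bias. x: (N, C, H, W), w: (O, C, KH, KW)."""
    n, _, h, wd = x.shape
    o, _, kh, kw = w.shape
    (pt, pb), (pl, pr) = same_pad(kh), same_pad(kw)
    xp = np.pad(x, ((0, 0), (0, 0), (pt, pb), (pl, pr)))
    out = np.zeros((n, o, h, wd), dtype=x.dtype)
    for p in range(h):
        for q in range(wd):
            window = xp[:, :, p:p + kh, q:q + kw]  # (N, C, KH, KW)
            out[:, :, p, q] = np.einsum("nckl,ockl->no", window, w)
    return out


def grad_sum_relu_conv(x, w):
    """d/dx of sum(relu(conv2d_same(x, w))), scattering each output's gradient back to its window."""
    n, c, h, wd = x.shape
    _, _, kh, kw = w.shape
    (pt, pb), (pl, pr) = same_pad(kh), same_pad(kw)
    upstream = (conv2d_same(x, w) > 0).astype(x.dtype)  # ReLU derivative, (N, O, H, W)
    gp = np.zeros((n, c, h + pt + pb, wd + pl + pr), dtype=x.dtype)  # grad w.r.t. padded input
    for p in range(h):
        for q in range(wd):
            gp[:, :, p:p + kh, q:q + kw] += np.einsum("no,ockl->nckl", upstream[:, :, p, q], w)
    return gp[:, :, pt:pt + h, pl:pl + wd]  # drop the padding region


x = np.full((2, 2, 2, 2), 2.0, dtype=np.float32)  # same input as the example
w = np.ones((2, 2, 2, 2), dtype=np.float32)       # weight_init='ones', 2 in/out channels
grad = grad_sum_relu_conv(x, w)
```

Each input element's gradient is 2 (one contribution per output channel) times the number of 2x2 windows that cover it: 1 for the top-left element, 2 for the two edge elements, and 4 for the bottom-right one. Recomputation leaves this gradient unchanged; it only affects which intermediate values are kept in memory between the forward and backward passes.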