mindspore.recompute

mindspore.recompute(block, *args, use_reentrant=True, output_recompute=False, **kwargs)

This function is used to reduce memory usage. When running block, rather than storing the intermediate activations computed in the forward pass, they are recomputed in the backward pass.

Note

The recompute function only supports blocks that inherit from the Cell object.

Parameters
  • block (Cell) – Block to be recomputed.

  • args (tuple) – Inputs for the block object to run the forward pass.

Keyword Arguments
  • use_reentrant (bool, optional) – This keyword is only valid in PyNative mode. If use_reentrant=True is set, recomputation is implemented through a custom bprop function, which does not support differentiation of complex types such as List/Tuple. If use_reentrant=False is set, recomputation is implemented through the saved_tensors_hook functionality, which supports differentiation of tensors inside complex types; see the use_reentrant sketch after this list. Default: True.

  • output_recompute (bool, optional) – This keyword is only valid in PyNative mode. If output_recompute=True is set, recomputation is implemented through the saved_tensors_hook functionality by default, and the output of this cell or function will not be stored for the backward pass by subsequent operations. When two adjacent cells both require recomputation (where the output of one cell serves as the input to the other), the recomputation of the two cells is merged; in that case, the output activation values of the first cell are not stored. If output_recompute=False is set, adjacent cells are not merged; see the output_recompute sketch after this list. Default: False.

  • **kwargs – Other keyword arguments for the block object to run the forward pass.
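
A minimal sketch of the use_reentrant=False path, assuming PyNative mode (the TupleCell and TupleNet names are illustrative, not part of this API). The recomputed block returns a Tuple of tensors, which the custom bprop path of use_reentrant=True cannot differentiate:

>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, ops, recompute
>>> ms.set_context(mode=ms.PYNATIVE_MODE)  # use_reentrant is only valid in PyNative mode
>>> class TupleCell(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.conv = nn.Conv2d(2, 2, 2, has_bias=False, weight_init='ones')
...
...     def construct(self, x):
...         y = self.conv(x)
...         return y, y * 2   # Tuple output: needs the saved_tensors_hook path
>>> class TupleNet(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.block = TupleCell()
...
...     def construct(self, x):
...         # with use_reentrant=False, tensors inside the Tuple stay differentiable
...         a, b = recompute(self.block, x, use_reentrant=False)
...         return a + b
>>> inputs = Tensor(np.ones([2, 2, 2, 2]).astype(np.float32))
>>> grad = ops.grad(TupleNet())(inputs)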
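
A minimal sketch of the output_recompute merging behavior, again assuming PyNative mode (ChainNet and its Dense blocks are illustrative): two adjacent recomputed blocks are chained, so their recomputation is merged and the first block's output activations are not stored:

>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, ops, recompute
>>> ms.set_context(mode=ms.PYNATIVE_MODE)  # output_recompute is only valid in PyNative mode
>>> class ChainNet(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.block1 = nn.Dense(4, 4)
...         self.block2 = nn.Dense(4, 4)
...
...     def construct(self, x):
...         # block1's output feeds block2 and both are recomputed, so their
...         # recomputation is merged and y's activation values are not stored
...         y = recompute(self.block1, x, output_recompute=True)
...         z = recompute(self.block2, y, output_recompute=True)
...         return z
>>> x = Tensor(np.ones([2, 4]).astype(np.float32))
>>> grads = ops.grad(ChainNet())(x)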

Returns

Same as the return type of block.

Raises

TypeError – If block is not a Cell object.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import ops
>>> from mindspore import Tensor, recompute
>>> class MyCell(nn.Cell):
...     def __init__(self):
...         super(MyCell, self).__init__(auto_prefix=False)
...         self.conv = nn.Conv2d(2, 2, 2, has_bias=False, weight_init='ones')
...         self.relu = ops.ReLU()
...
...     def construct(self, x):
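...         # recompute conv's activations in the backward pass instead of storing them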
...         y = recompute(self.conv, x)
...         return self.relu(y)
>>> inputs = Tensor(np.ones([2, 2, 2, 2]).astype(np.float32) * 2)
>>> my_net = MyCell()
>>> grad = ops.grad(my_net)(inputs)
>>> print(grad)
[[[[2. 4.]
   [4. 8.]]
  [[2. 4.]
   [4. 8.]]]
 [[[2. 4.]
   [4. 8.]]
  [[2. 4.]
   [4. 8.]]]]