mindspore.recompute
- mindspore.recompute(block, *args, use_reentrant=True, output_recompute=False, early_stop=True, context_fn=None, **kwargs)[source]
This function is used to reduce memory usage: when running the block, instead of storing the intermediate activations computed in the forward pass, they are recomputed in the backward pass.
Note
The recompute function only supports blocks that inherit from the Cell class.
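To make the memory/compute trade-off concrete, here is a minimal pure-Python sketch of the recompute idea (illustrative only, not MindSpore internals): a standard backward pass reuses the activation cached during the forward pass, while a recomputed backward pass discards it and re-runs the forward computation to rebuild it.

```python
import math

def f(x):
    return math.exp(x)              # activation y = e^x

def g(y):
    return y * y                    # z = y^2

# Without recompute: keep the activation y from the forward pass
# and reuse it in the backward pass.
def forward_and_backward_stored(x):
    y = f(x)                        # stored for backward
    z = g(y)
    dz_dy = 2 * y                   # backward of g uses the stored y
    dz_dx = dz_dy * y               # d(e^x)/dx = e^x = y
    return z, dz_dx

# With recompute: do not keep y after the forward pass;
# re-run f(x) during the backward pass to rebuild it.
def forward_and_backward_recomputed(x):
    z = g(f(x))                     # y is not kept alive
    y = f(x)                        # recomputed in the backward pass
    dz_dy = 2 * y
    dz_dx = dz_dy * y
    return z, dz_dx
```

Both variants produce identical results; recompute trades the extra forward evaluation for not holding the activation in memory between the forward and backward passes.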
- Parameters
block (Cell) – The network block whose intermediate activations are recomputed in the backward pass.
args (tuple) – Inputs of the block.
- Keyword Arguments
use_reentrant (bool, optional) – This keyword is only valid in PyNative mode. If use_reentrant=True is set, recomputation is implemented through a custom bprop function, which does not support differentiation of complex types such as List/Tuple. If use_reentrant=False is set, recomputation is implemented with the saved_tensors_hook functionality, which supports differentiation of tensors inside complex types. Default: True.
output_recompute (bool, optional) – If output_recompute=True is set, recomputation is implemented through the saved_tensors_hook functionality by default, and the output of this cell or function will not be stored by subsequent operations for backward. When two adjacent cells both require recomputation (the output of one cell serving as the input of the other), their recomputation is merged; in this case, the output activation values of the first cell are not stored. If output_recompute=False, adjacent cells are not merged. Default: False.
early_stop (bool, optional) – This keyword is only valid in PyNative mode when use_reentrant=False. If True, non-reentrant recompute stops recomputation as soon as it has computed all needed tensors. This can reduce unnecessary computation when the forward function ends with operations that do not save tensors (e.g., clone or view operations). This argument is ignored if use_reentrant=True. Note: higher-order differentiation is only supported when early_stop=False. Default: True.
context_fn (Callable, optional) – A callable returning a tuple of two context managers. The first context manager is applied during the forward pass, and the second is applied during recomputation. This is useful for applying different settings (e.g., enabling/disabling gradient computation, mixed precision) during the forward and recomputation phases. This argument is ignored if use_reentrant=True. Default: None, which returns two null contexts.
**kwargs – Other arguments.
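To illustrate the shape a context_fn must have, here is a hedged sketch: per the description above, it is a callable returning a tuple of two context managers, the first wrapping the forward pass and the second wrapping recomputation. Both are null contexts here; in practice they might toggle settings such as mixed precision. The name my_context_fn is a hypothetical example, not part of the API.

```python
from contextlib import nullcontext

def my_context_fn():
    """Return (forward_ctx, recompute_ctx), as context_fn requires."""
    fwd_ctx = nullcontext()        # applied during the forward pass
    recompute_ctx = nullcontext()  # applied during recomputation
    return fwd_ctx, recompute_ctx

# It would then be passed as (sketch, assuming a Cell `block` and input `x`):
# out = mindspore.recompute(block, x, use_reentrant=False, context_fn=my_context_fn)
```

Since context_fn is ignored when use_reentrant=True, it only takes effect together with use_reentrant=False.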
- Returns
Same as the return type of block.
- Raises
TypeError – If block is not a Cell object.
- Supported Platforms:
Ascend GPU CPU
Examples
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import ops
>>> from mindspore import Tensor, recompute
>>> class MyCell(nn.Cell):
...     def __init__(self):
...         super(MyCell, self).__init__(auto_prefix=False)
...         self.conv = nn.Conv2d(2, 2, 2, has_bias=False, weight_init='ones')
...         self.relu = ops.ReLU()
...
...     def construct(self, x):
...         y = recompute(self.conv, x)
...         return self.relu(y)
>>> inputs = Tensor(np.ones([2, 2, 2, 2]).astype(np.float32) * 2)
>>> my_net = MyCell()
>>> grad = ops.grad(my_net)(inputs)
>>> print(grad)
[[[[2. 4.]
   [4. 8.]]
  [[2. 4.]
   [4. 8.]]]
 [[[2. 4.]
   [4. 8.]]
  [[2. 4.]
   [4. 8.]]]]