mindspore.recompute
- mindspore.recompute(block, *args, use_reentrant=True, output_recompute=False, **kwargs)[source]
This function is used to reduce memory usage. When running block, rather than storing the intermediate activations computed in the forward pass, we recompute them in the backward pass.
Note
The recompute function only supports a block that inherits from the Cell object.
- Parameters
block (Cell) – The block whose forward computation is to be recomputed in the backward pass.
args – Positional inputs passed to block.
- Keyword Arguments
use_reentrant (bool, optional) – This keyword is only valid in PyNative mode. If use_reentrant=True is set, recomputation is implemented through a custom bprop function, which does not support differentiation of complex types such as List/Tuple. If use_reentrant=False is set, recomputation is implemented with the saved_tensors_hook functionality, which supports differentiation of tensors inside complex types. Default: True.
output_recompute (bool, optional) – This keyword is only valid in PyNative mode. If output_recompute=True is set, recomputation is implemented with the saved_tensors_hook functionality by default, and the output of this cell or function will not be stored by subsequent operations for backward. When two adjacent cells both require recomputation (where the output of one cell serves as the input to the other), the recomputation of the two cells is merged; in this case, the output activation values of the first cell are not stored. If output_recompute=False is set, adjacent cells are not merged. Default: False.
**kwargs – Other keyword arguments passed to block.
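Conceptually, recomputation trades compute for memory: instead of caching an intermediate activation during the forward pass, the backward pass re-runs the relevant forward computation to reproduce it. The following is a minimal, framework-independent Python sketch of this idea using plain NumPy (not the MindSpore API; the function name and the toy `relu(x * w)` model are illustrative assumptions):

```python
import numpy as np

def forward_with_recompute(x, w):
    """Forward pass for y = relu(x * w) that does NOT cache the
    intermediate pre-activation; the returned backward closure
    recomputes it on demand instead (illustrative sketch only)."""
    y = np.maximum(x * w, 0.0)  # intermediate z = x * w is discarded here

    def backward(grad_y):
        # Recompute the forward step rather than reading a stored value.
        z = x * w                          # recomputed pre-activation
        relu_grad = (z > 0).astype(x.dtype)
        return grad_y * relu_grad * w      # dL/dx by the chain rule
    return y, backward

x = np.array([1.0, -2.0, 3.0])
w = np.array([2.0, 2.0, 2.0])
y, backward = forward_with_recompute(x, w)
grad_x = backward(np.ones_like(y))
```

The memory saving comes from dropping `z` between the forward and backward passes, at the cost of one extra multiplication in `backward`; real frameworks such as MindSpore apply the same idea to whole Cells.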
- Returns
Same as return type of block.
- Raises
TypeError – If block is not a Cell object.
- Supported Platforms:
Ascend GPU CPU
Examples
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import ops
>>> from mindspore import Tensor, recompute
>>> class MyCell(nn.Cell):
...     def __init__(self):
...         super(MyCell, self).__init__(auto_prefix=False)
...         self.conv = nn.Conv2d(2, 2, 2, has_bias=False, weight_init='ones')
...         self.relu = ops.ReLU()
...
...     def construct(self, x):
...         y = recompute(self.conv, x)
...         return self.relu(y)
>>> inputs = Tensor(np.ones([2, 2, 2, 2]).astype(np.float32) * 2)
>>> my_net = MyCell()
>>> grad = ops.grad(my_net)(inputs)
>>> print(grad)
[[[[2. 4.]
   [4. 8.]]
  [[2. 4.]
   [4. 8.]]]
 [[[2. 4.]
   [4. 8.]]
  [[2. 4.]
   [4. 8.]]]]