mindspore.recompute
- mindspore.recompute(block, *args, **kwargs)
Reduces memory usage by trading computation for storage. When block runs, the intermediate activations computed in the forward pass are not stored; they are recomputed during the backward pass instead.
Note
The recompute function only supports blocks that inherit from Cell.
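The trade-off described above can be illustrated without MindSpore. The following is a minimal NumPy sketch, not the library's implementation: `ToyBlock`, `run_stored`, and `run_recomputed` are hypothetical names introduced only for this illustration. It compares keeping the pre-activation alive until the backward pass with discarding it and rebuilding it from the saved input.

```python
import numpy as np


class ToyBlock:
    """Hypothetical block: a linear layer followed by ReLU."""

    def __init__(self, weight):
        self.weight = weight
        self.forward_calls = 0  # counts how often the forward pass runs

    def forward(self, x):
        self.forward_calls += 1
        z = x @ self.weight              # intermediate activation
        return z, np.maximum(z, 0.0)     # (pre-activation, output)

    def backward(self, z, grad_out):
        # Gradient with respect to the block input, given the pre-activation z.
        return (grad_out * (z > 0)) @ self.weight.T


def run_stored(block, x, grad_out):
    # Conventional strategy: z is kept in memory until the backward pass.
    z, out = block.forward(x)
    return out, block.backward(z, grad_out)


def run_recomputed(block, x, grad_out):
    # Recompute strategy: only the input x is kept after the forward pass.
    _, out = block.forward(x)            # z is discarded here
    z, _ = block.forward(x)              # backward pass rebuilds z from x
    return out, block.backward(z, grad_out)


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3))
w = rng.standard_normal((3, 5))
g = np.ones((4, 5))

stored_block, recomputed_block = ToyBlock(w), ToyBlock(w)
_, grad_stored = run_stored(stored_block, x, g)
_, grad_recomputed = run_recomputed(recomputed_block, x, g)

print(np.allclose(grad_stored, grad_recomputed))
print(stored_block.forward_calls, recomputed_block.forward_calls)
```

Both strategies produce the same gradient; the recompute variant pays for the lower memory footprint with a second forward pass through the block.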
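- Parameters
block (Cell) – The block to be recomputed.
args – Positional inputs passed to block for the forward pass.
kwargs – Additional optional keyword arguments.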
- Returns
Same as return type of block.
- Raises
TypeError – If block is not a Cell object.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import ops
>>> from mindspore import Tensor, recompute
>>> class MyCell(nn.Cell):
...     def __init__(self):
...         super(MyCell, self).__init__(auto_prefix=False)
...         self.conv = nn.Conv2d(2, 2, 2, has_bias=False, weight_init='ones')
...         self.relu = ops.ReLU()
...
...     def construct(self, x):
...         y = recompute(self.conv, x)
...         return self.relu(y)
>>> inputs = Tensor(np.ones([2, 2, 2, 2]).astype(np.float32) * 2)
>>> my_net = MyCell()
>>> grad = ops.grad(my_net)(inputs)
>>> print(grad)
[[[[2. 4.]
   [4. 8.]]
  [[2. 4.]
   [4. 8.]]]
 [[[2. 4.]
   [4. 8.]]
  [[2. 4.]
   [4. 8.]]]]
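The gradient printed in the example can be cross-checked by hand. The sketch below is a NumPy-only illustration, with `input_grad_ones_kernel` as a hypothetical helper name. It rests on three assumptions that the example itself does not state: the convolution uses stride 1 with 'same' padding that places the single extra row and column of zeros on the bottom and right, `ops.grad` differentiates the sum of the network output, and ReLU passes the gradient through unchanged because every convolution output is positive for these inputs.

```python
import numpy as np


def input_grad_ones_kernel(shape, out_channels, kernel):
    """Gradient of sum(conv(x)) with respect to x for an all-ones kernel.

    Assumption: stride 1, and zero padding of (kernel - 1) rows/columns on
    the bottom/right only, which is the assumed 'same' layout for kernel=2.
    """
    n, c_in, h, w = shape
    grad = np.zeros(shape, dtype=np.float32)
    for i in range(h):                 # output row
        for j in range(w):             # output column
            for di in range(kernel):
                for dj in range(kernel):
                    r, c = i + di, j + dj
                    if r < h and c < w:  # skip positions that fall in the padding
                        # Each output channel contributes weight 1 to this input.
                        grad[:, :, r, c] += out_channels
    return grad


expected = input_grad_ones_kernel((2, 2, 2, 2), out_channels=2, kernel=2)
print(expected[0, 0])
```

Under these assumptions, every 2x2 slice of `expected` equals `[[2, 4], [4, 8]]`, which agrees with the values printed by the example; each entry counts how many output positions read that input element, multiplied by the two output channels.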