mindspore.recompute

mindspore.recompute(block, *args, **kwargs)

This function reduces memory usage by trading memory for extra computation. When block runs, the intermediate activations computed in the forward pass are not stored. Instead, they are recomputed in the backward pass.
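The idea can be illustrated without MindSpore. Below is a minimal plain-Python sketch that assumes a toy block made of scalar `tanh` layers. `block_forward`, `block_backward`, `grad_stored` and `grad_recomputed` are hypothetical names used only for illustration, not MindSpore APIs. The recompute path keeps only the block input after the forward pass. During backward, it reruns the forward pass to rebuild the activations. The resulting gradient is the same as in the store-everything path.

```python
import math

# Hypothetical toy block (not MindSpore code): h_{k+1} = tanh(w_k * h_k).

def block_forward(x, weights, keep_activations):
    """Run the block; optionally keep every intermediate activation."""
    acts = [x]
    h = x
    for w in weights:
        h = math.tanh(w * h)
        if keep_activations:
            acts.append(h)
    # Without keep_activations, only the block input is retained.
    return h, (acts if keep_activations else [x])

def block_backward(acts, weights, grad_out):
    """Backpropagate d(out)/d(x) through the chain using stored activations."""
    g = grad_out
    for k in reversed(range(len(weights))):
        h_out = acts[k + 1]
        # d tanh(w*h)/dh = w * (1 - tanh(w*h)^2) = w * (1 - h_out^2)
        g = g * weights[k] * (1.0 - h_out * h_out)
    return g

def grad_stored(x, weights):
    # Usual training: all activations are kept from forward until backward.
    _, acts = block_forward(x, weights, keep_activations=True)
    return block_backward(acts, weights, 1.0), len(acts)

def grad_recomputed(x, weights):
    # Recompute: the forward pass keeps only the block input ...
    _, saved = block_forward(x, weights, keep_activations=False)
    kept_after_forward = len(saved)
    # ... and the backward pass reruns the forward to rebuild the activations.
    _, acts = block_forward(saved[0], weights, keep_activations=True)
    return block_backward(acts, weights, 1.0), kept_after_forward

weights = [0.5, -1.2, 0.8, 2.0]
print(grad_stored(0.3, weights))      # gradient, 5 values kept after forward
print(grad_recomputed(0.3, weights))  # same gradient, 1 value kept after forward
```

The gradient is identical in both paths because the recomputed forward pass repeats exactly the same operations. The cost is one extra forward pass through the block.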

Note

The recompute function only supports a block that inherits from Cell.

Parameters
  • block (Cell) – Block to be recomputed.

  • args (tuple) – Inputs passed to block to run the forward pass.

  • kwargs (dict) – Optional keyword input for the recompute function.

Returns

Same as return type of block.

Raises

TypeError – If block is not a Cell object.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import ops
>>> from mindspore import Tensor, recompute
>>> class MyCell(nn.Cell):
...     def __init__(self):
...         super(MyCell, self).__init__(auto_prefix=False)
...         self.conv = nn.Conv2d(2, 2, 2, has_bias=False, weight_init='ones')
...         self.relu = ops.ReLU()
...
...     def construct(self, x):
...         y = recompute(self.conv, x)
...         return self.relu(y)
>>> inputs = Tensor(np.ones([2, 2, 2, 2]).astype(np.float32) * 2)
>>> my_net = MyCell()
>>> grad = ops.grad(my_net)(inputs)
>>> print(grad)
[[[[2. 4.]
   [4. 8.]]
  [[2. 4.]
   [4. 8.]]]
 [[[2. 4.]
   [4. 8.]]
  [[2. 4.]
   [4. 8.]]]]