mindspore.saved_tensors_hooks
- mindspore.saved_tensors_hooks(pack_hook, unpack_hook)[source]
A context manager used to customize how saved tensors are packed and unpacked.
Certain tensors from the forward pass are stored for use in the backward pass. By using this context manager, users can specify:
- how these tensors are packed before they are saved (the pack stage);
- how they are restored when accessed during gradient computation (the unpack stage).
The hooks should have the following signatures:
pack_hook(tensor: Tensor) -> Any: accepts a tensor and returns an arbitrary object that represents its stored form.
unpack_hook(packed: Any) -> Tensor: accepts the object returned by pack_hook and restores the corresponding tensor.
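For illustration, here is a minimal sketch of such a pair of hooks (assuming float32 saved tensors; the float16 round trip is lossy and purely illustrative, and pack_to_fp16 / unpack_to_fp32 are hypothetical names) that stores saved tensors in half precision and restores them on access:
>>> import mindspore as ms
>>> from mindspore import ops
>>> def pack_to_fp16(tensor):
...     # store the saved tensor in half precision to reduce memory use
...     return tensor.astype(ms.float16)
>>>
>>> def unpack_to_fp32(packed):
...     # restore the dtype the backward computation expects
...     return packed.astype(ms.float32)
>>>
>>> x = ops.ones(2, dtype=ms.float32)
>>> with ms.saved_tensors_hooks(pack_to_fp16, unpack_to_fp32):
...     out = x * x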
Note
This context manager is currently not supported in Graph mode or JIT mode.
Warning
Performing in-place modifications on the tensor passed to pack_hook is not allowed.
To prevent reference cycles, the object returned by pack_hook must not hold a direct reference to the original tensor.
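One pattern that satisfies both warnings (a hedged sketch; pack_detached / unpack_detached are hypothetical names) is to copy the tensor's data out during packing, so the stored object neither modifies nor references the original tensor:
>>> import mindspore as ms
>>> def pack_detached(tensor):
...     # copy the data to a NumPy array; asnumpy() may share memory with
...     # the tensor, so take an explicit copy to drop any direct reference
...     return tensor.asnumpy().copy()
>>>
>>> def unpack_detached(packed):
...     # rebuild a Tensor from the copied data for the backward pass
...     return ms.Tensor(packed)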
- Parameters
pack_hook (Callable) – A function that defines how to process a tensor before it is saved during the forward pass.
unpack_hook (Callable) – A function that defines how to recover the tensor when it is needed during the backward computation.
- Supported Platforms:
Ascend GPU CPU
Examples
>>> import mindspore as ms
>>> from mindspore import ops
>>> def pack_hook(x):
...     print("packing ", x)
...     return x + 1
>>>
>>> def unpack_hook(x):
...     print("unpacking ", x)
...     return x
>>>
>>> def forward_fn(x, y):
...     with ms.saved_tensors_hooks(pack_hook, unpack_hook):
...         out = x * y
...     print("forward end")
...     return out
>>> x = ops.ones(2, dtype=ms.float32)
>>> y = ops.ones(2, dtype=ms.float32)
>>> ms.value_and_grad(forward_fn, grad_position=(0, 1))(x, y)
packing [1. 1.]
packing [1. 1.]
forward end
unpacking [2. 2.]
unpacking [2. 2.]
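In this output, "packing" appears twice because the backward rule of x * y needs both operands, so both are saved through pack_hook. The unpacked values are [2. 2.] rather than the original [1. 1.] because pack_hook stored x + 1 and unpack_hook returns the packed object unchanged, confirming that the backward pass sees exactly what the hooks produce.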