mindspore.saved_tensors_hooks
- class mindspore.saved_tensors_hooks(pack_hook, unpack_hook)
A context manager used to customize how saved tensors are packed and unpacked.
Some tensors computed in the forward pass are saved for use in the backward pass. Within this context, users can specify:
- how these tensors are packed before they are saved (pack stage);
- how they are restored when they are accessed during gradient computation (unpack stage).
The hooks should have the following signatures:
pack_hook(tensor: Tensor) -> Any: Accepts a tensor and returns an arbitrary object that represents the stored form of the tensor.
unpack_hook(packed: Any) -> Tensor: Accepts the object returned by pack_hook and restores the corresponding tensor.
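A hook pair matching these signatures might, for example, trade precision for memory by keeping a half-precision copy of each saved value. The sketch below is a plain-Python illustration under stated assumptions: a "tensor" is modeled as a list of floats, and the packed form is a hypothetical `(count, bytes)` tuple built with the standard `struct` module. Hooks written for real MindSpore tensors would need a different conversion.

```python
import struct
from typing import List, Tuple

# Hypothetical packed form: (number of elements, raw little-endian float16 bytes).
Packed = Tuple[int, bytes]


def pack_hook(values: List[float]) -> Packed:
    """Pack stage (sketch): store a half-precision copy, 2 bytes per element."""
    n = len(values)
    return n, struct.pack(f"<{n}e", *values)


def unpack_hook(packed: Packed) -> List[float]:
    """Unpack stage (sketch): rebuild the values from the bytes written by pack_hook."""
    n, raw = packed
    return list(struct.unpack(f"<{n}e", raw))


exact = [1.0, 0.5, -3.0]   # exactly representable in float16
approx = [0.1, 2.7]        # rounded when stored as float16
restored_exact = unpack_hook(pack_hook(exact))
restored_approx = unpack_hook(pack_hook(approx))
```

Because the stored copy is rounded to float16, the values seen during the backward pass (and therefore the gradients) can differ slightly from those of the forward pass; using the `"d"` format instead would make the round trip lossless at four times the storage. The packed tuple holds only an integer and a `bytes` object, never the input itself.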
Note
This context manager is currently not supported in Graph mode or JIT mode.
Warning
To prevent undefined behavior, an exception is raised if the original tensor passed to pack_hook is modified in place.
To prevent reference cycles, the object returned by pack_hook must not hold a direct reference to the original tensor.
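The two rules in this warning can be illustrated with a toy model. This is a hypothetical sketch, not MindSpore's mechanism: `ToyTensor`, `save_for_backward`, `copying_pack` and `leaky_pack` are invented for illustration, the reference check is deliberately shallow, and MindSpore may detect violations at a different point and raise a different exception type.

```python
# Toy model (hypothetical, not MindSpore code) of the two pack_hook rules.


class ToyTensor:
    """Hypothetical tensor: a list of floats plus a 'saved for backward' flag."""

    def __init__(self, data):
        self.data = list(data)
        self.saved_for_backward = False

    def add_(self, k):
        """In-place add, refused once the tensor has been handed to a pack hook."""
        if self.saved_for_backward:
            raise RuntimeError("tensor saved for backward was modified in place")
        self.data = [v + k for v in self.data]


def save_for_backward(tensor, pack_hook):
    """Toy save step: mark the tensor, pack it, and reject leaky packed objects."""
    tensor.saved_for_backward = True
    packed = pack_hook(tensor)
    # Shallow check only: the packed object must not be, or directly contain,
    # the original tensor (that would keep it alive and can create cycles).
    parts = packed if isinstance(packed, (tuple, list)) else (packed,)
    if any(p is tensor for p in parts):
        raise ValueError("pack_hook result holds a direct reference to the tensor")
    return packed


def copying_pack(t):
    return tuple(t.data)  # new object, no reference to t


def leaky_pack(t):
    return ("wrapped", t)  # holds t directly, so the toy rejects it


inplace_error = None
leak_error = None

x = ToyTensor([1.0, 2.0])
packed = save_for_backward(x, copying_pack)
try:
    x.add_(1.0)  # refused: x was saved for backward
except RuntimeError as err:
    inplace_error = str(err)

try:
    save_for_backward(ToyTensor([3.0]), leaky_pack)  # refused: result holds the tensor
except ValueError as err:
    leak_error = str(err)
```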
- Parameters
pack_hook (Callable) – A function that defines how to process a tensor before it is saved during the forward pass.
unpack_hook (Callable) – A function that defines how to recover the tensor when it is needed during the backward computation.
- Supported Platforms:
Ascend GPU CPU
Examples
>>> import mindspore as ms
>>> from mindspore import ops
>>> def pack_hook(x):
...     print("packing", x)
...     return x + 1
>>>
>>> def unpack_hook(x):
...     print("unpacking", x)
...     return x
>>>
>>> def forward_fn(x, y):
...     with ms.saved_tensors_hooks(pack_hook, unpack_hook):
...         out = x * y
...     print("forward end")
...     return out
>>> x = ops.ones(2, dtype=ms.float32)
>>> y = ops.ones(2, dtype=ms.float32)
>>> value, grads = ms.value_and_grad(forward_fn, grad_position=(0, 1))(x, y)
packing [1. 1.]
packing [1. 1.]
forward end
unpacking [2. 2.]
unpacking [2. 2.]