mindspore.saved_tensors_hooks

查看源文件
class mindspore.saved_tensors_hooks(pack_hook, unpack_hook)[源代码]

一个上下文管理器,用于自定义保存张量(Saved Tensor)的打包(pack)和解包(unpack)方式。

在前向计算中,某些张量会被保存,以供反向传播时使用。通过使用该上下文,用户可以指定:

  • 在保存前如何处理这些张量(打包阶段)。

  • 在反向访问时如何恢复这些张量(解包阶段)。

打包和解包函数应符合以下签名:

  • pack_hook(tensor: Tensor) -> Any: 接收一个张量,并返回任意对象,用于表示该张量在存储阶段的形式。

  • unpack_hook(packed: Any) -> Tensor: 接收上述返回值,并恢复出对应的张量。

说明

当前该上下文管理器在Graph模式与Jit模式下不支持。

警告

  • 不允许在 pack_hook 中对传入的原始张量进行原地(in-place)修改。

  • 为避免产生循环引用, pack_hook 返回的对象不能直接持有原始张量的引用。

参数:
  • pack_hook (Callable) - 定义前向计算保存张量时的处理方法。

  • unpack_hook (Callable) - 定义反向计算恢复张量时的处理方法。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore as ms
>>> from mindspore import ops
>>> def pack_hook(x):
...     print("packing ", x)
...     return x + 1
>>>
>>> def unpack_hook(x):
...     print("unpacking ", x)
...     return x
>>>
>>> def forward_fn(x, y):
...     with ms.saved_tensors_hooks(pack_hook, unpack_hook):
...         out = x * y
...     print("forward end")
...     return out
>>> x = ops.ones(2, dtype=ms.float32)
>>> y = ops.ones(2, dtype=ms.float32)
>>> ms.value_and_grad(forward_fn, grad_position=(0,1))(x, y)
packing [1. 1.]
packing [1. 1.]
forward end
unpacking [2. 2.]
unpacking [2. 2.]