mindspore.register_saved_tensors_hooks

查看源文件
mindspore.register_saved_tensors_hooks(pack_hook, unpack_hook)[源代码]

一个静态图模式下的装饰器,用于自定义保存张量(Saved Tensor)的打包(pack)和解包(unpack)方式。

功能上等价于动态图模式的 with mindspore.saved_tensors_hooks(pack_hook, unpack_hook)。 更多详细信息请参考 mindspore.saved_tensors_hooks

说明

  • 该装饰器只支持图模式。

  • pack_hookunpack_hook 必须满足图模式下的语法约束。

样例:

>>> import mindspore as ms
>>> from mindspore import register_saved_tensors_hooks
>>> from mindspore import ops
>>>
>>> def pack_hook(x):
...     print("packing ", x)
...     return x + 1
>>>
>>> def unpack_hook(x):
...     print("unpacking ", x)
...     return x
>>>
>>> @register_saved_tensors_hooks(pack_hook, unpack_hook)
... def forward_fn(x, y):
...     out = x * y
...     return out
>>>
>>> x = ops.ones(2, dtype=ms.float32)
>>> y = ops.ones(2, dtype=ms.float32)
>>> ms.jit(ms.value_and_grad(forward_fn, grad_position=(0,1)))(x, y)
packing
Tensor(shape=[2], dtype=Float32, value=[ 1.00000000e+00  1.00000000e+00])
packing
Tensor(shape=[2], dtype=Float32, value=[ 1.00000000e+00  1.00000000e+00])
unpacking
Tensor(shape=[2], dtype=Float32, value=[ 2.00000000e+00  2.00000000e+00])
unpacking
Tensor(shape=[2], dtype=Float32, value=[ 2.00000000e+00  2.00000000e+00])