mindspore.Tensor.register_hook

Tensor.register_hook(hook_fn)

Registers a backward hook for the tensor.

Note

  • register_hook(hook_fn) does not work in graph mode or in functions decorated with ‘jit’.

  • hook_fn receives grad, the gradient passed to the tensor, and can modify it by returning a new output gradient.

  • hook_fn must have the signature hook_fn(grad) -> new output gradient. It must return a gradient: returning None, or having no return statement at all, is not allowed.
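The hook contract above can be illustrated with a toy, pure-Python model of scalar reverse-mode differentiation. This is a sketch under stated assumptions, not MindSpore's implementation: `Node`, `backward`, and the `ValueError` raised when a hook returns `None` are hypothetical names and choices made only for this illustration (the documentation does not say which error MindSpore raises in that case), and no MindSpore install is needed to run it.

```python
# Toy scalar reverse-mode model of a tensor backward hook.
# Assumption: this only mimics the documented contract; it is NOT MindSpore's implementation.


class Node:
    """A scalar value that records how it was computed, for reverse-mode differentiation."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuple of (parent_node, local_derivative) pairs
        self.grad = 0.0
        self.hooks = []

    def __mul__(self, other):
        # d(a*b)/da = b and d(a*b)/db = a
        return Node(self.value * other.value, ((self, other.value), (other, self.value)))

    def register_hook(self, hook_fn):
        if not callable(hook_fn):
            raise TypeError("hook_fn must be a Python function")
        self.hooks.append(hook_fn)


def backward(output):
    # Topologically order the graph so each node's gradient is complete before it is used.
    order, seen = [], set()

    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)

    visit(output)
    output.grad = 1.0
    for node in reversed(order):
        # Each hook receives the gradient flowing into this node and returns a replacement.
        for hook in node.hooks:
            new_grad = hook(node.grad)
            if new_grad is None:
                raise ValueError("hook_fn must return the new gradient, not None")
            node.grad = new_grad
        # Propagate the (possibly modified) gradient to the inputs via the chain rule.
        for parent, local_derivative in node.parents:
            parent.grad += local_derivative * node.grad


x = Node(3.0)
w = x * x                                 # intermediate value carrying the hook
w.register_hook(lambda grad: grad * 10)   # scale the gradient reaching w by 10
out = w * x                               # out = x**3, so without the hook d(out)/dx = 3 * x**2 = 27
backward(out)
print(x.grad)  # 189.0
```

Because the hook runs before the gradient is propagated to w's inputs, only the path through w is scaled (2 × 3 × 30 = 180); the direct path from out to x still contributes w = 9.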

Parameters

hook_fn (function) – A Python function to use as the tensor's backward hook.

Returns

A handle corresponding to hook_fn. Calling handle.remove() removes the registered hook_fn.
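The removable-handle behaviour described above can be sketched in plain Python. `HookHandle` and `HookedGrad` are hypothetical names for this illustration only; the sketch mirrors the documented idea (the returned handle unregisters its hook_fn when remove() is called) and makes no claim about how MindSpore stores or orders hooks internally.

```python
# Toy model of a removable hook handle.
# Assumption: HookHandle and HookedGrad are hypothetical; this is NOT MindSpore's code.
import itertools


class HookHandle:
    def __init__(self, hooks, hook_id):
        self._hooks = hooks
        self._hook_id = hook_id

    def remove(self):
        # In this toy, removing an already-removed hook is a no-op.
        self._hooks.pop(self._hook_id, None)


class HookedGrad:
    """Holds the backward hooks registered on one value."""

    def __init__(self):
        self._hooks = {}  # dicts keep insertion order, so this toy runs hooks in registration order
        self._ids = itertools.count()

    def register_hook(self, hook_fn):
        if not callable(hook_fn):
            raise TypeError("hook_fn must be a Python function")
        hook_id = next(self._ids)
        self._hooks[hook_id] = hook_fn
        return HookHandle(self._hooks, hook_id)

    def apply(self, grad):
        # Pass the incoming gradient through every registered hook in turn.
        for hook_fn in list(self._hooks.values()):
            grad = hook_fn(grad)
        return grad


value = HookedGrad()
handle = value.register_hook(lambda grad: grad * 2)
print(value.apply(3.0))  # 6.0
handle.remove()
print(value.apply(3.0))  # 3.0
```

Giving each registration its own handle lets a caller remove one specific hook without having to keep a reference to the original hook_fn.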

Raises

TypeError – If hook_fn is not a Python function.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore as ms
>>> from mindspore import Tensor
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def hook_fn(grad):
...     return grad * 2
...
>>> def hook_test(x, y):
...     z = x * y
...     z.register_hook(hook_fn)
...     z = z * y
...     return z
...
>>> ms_grad = ms.grad(hook_test, grad_position=(0,1))
>>> output = ms_grad(Tensor(1, ms.float32), Tensor(2, ms.float32))
>>> print(output)
(Tensor(shape=[], dtype=Float32, value=8), Tensor(shape=[], dtype=Float32, value=6))
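The printed values can be checked by hand with the chain rule. Writing $z_1 = xy$ for the intermediate tensor that carries the hook, $z = z_1 y$ for the returned value, and $\tilde{g}_{z_1}$ for the gradient of $z_1$ after hook_fn doubles it, at $x = 1$, $y = 2$:

```latex
\begin{aligned}
\tilde{g}_{z_1} &= 2\,\frac{\partial z}{\partial z_1} = 2y = 4 \\
\frac{\partial z}{\partial x} &= \tilde{g}_{z_1}\,\frac{\partial z_1}{\partial x} = 4y = 8 \\
\frac{\partial z}{\partial y} &= z_1 + \tilde{g}_{z_1}\,\frac{\partial z_1}{\partial y} = xy + 4x = 2 + 4 = 6
\end{aligned}
```

Without the hook, $z = xy^2$ and the gradients would be $y^2 = 4$ and $2xy = 4$, so the hook changes the result from (4, 4) to (8, 6).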