mindspore.vjp

mindspore.vjp(fn, *inputs, has_aux=False)

Compute the vector-Jacobian product (VJP) of the given network. VJP corresponds to reverse-mode automatic differentiation.
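To make the term concrete: for a function with Jacobian J evaluated at the inputs, the vector-Jacobian product for a vector v is v^T J. The following pure-Python sketch computes it by hand and does not use MindSpore. The function f, its analytic Jacobian jacobian_f, and the helper vjp_by_hand are hypothetical names chosen only for this illustration.

```python
def f(x):
    """Hypothetical function from R^2 to R^2, used only for illustration."""
    return [x[0] * x[1], x[0] + x[1]]


def jacobian_f(x):
    """Analytic Jacobian of f: row i holds the partial derivatives of output i."""
    return [[x[1], x[0]],
            [1.0, 1.0]]


def vjp_by_hand(x, v):
    """Return v^T J(x): one entry per input, each weighted by the vector v."""
    J = jacobian_f(x)
    return [sum(v[i] * J[i][j] for i in range(len(v))) for j in range(len(x))]


x = [3.0, 4.0]
v = [1.0, 2.0]   # one entry per output of f
result = vjp_by_hand(x, v)
print(result)    # [6.0, 5.0]
```

Note that v has one entry per output of f, while the result has one entry per input. An automatic differentiation framework obtains the same quantity without materializing J explicitly.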

Parameters
  • fn (Union[Function, Cell]) – The function or network that takes Tensor inputs and returns a single Tensor or a tuple of Tensors.

  • inputs (Union[Tensor, tuple[Tensor], list[Tensor]]) – The inputs to fn.

  • has_aux (bool) – If True, only the first output of fn contributes to the gradient, and the remaining outputs are returned unchanged. In this case fn must return more than one output. Default: False.

Returns

The forward outputs and a function that calculates the VJP.

  • net_output (Union[Tensor, tuple[Tensor]]) - The output of fn(inputs). Specifically, when has_aux is True, net_output is the first output of fn(inputs).

  • vjp_fn (Function) - Calculates the vector-Jacobian product. Its inputs are the vectors, whose shapes and types must match those of net_output.

  • aux_value (Union[Tensor, tuple[Tensor]], optional) - Returned only when has_aux is True. It holds the second through last outputs of fn(inputs) and does not contribute to the gradient.

Raises

TypeError – If inputs or v does not belong to the required types.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import vjp
>>> from mindspore import Tensor
>>> class Net(nn.Cell):
...     def construct(self, x, y):
...         return x**3 + y
>>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
>>> outputs, vjp_fn = vjp(Net(), x, y)
>>> print(outputs)
[[ 2. 10.]
 [30. 68.]]
>>> gradient = vjp_fn(v)
>>> print(gradient)
(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 3.00000000e+00,  1.20000000e+01],
 [ 2.70000000e+01,  4.80000000e+01]]), Tensor(shape=[2, 2], dtype=Float32, value=
[[ 1.00000000e+00,  1.00000000e+00],
 [ 1.00000000e+00,  1.00000000e+00]]))
>>> def fn(x, y):
...     return 2 * x + y, y ** 3
>>> outputs, vjp_fn, aux = vjp(fn, x, y, has_aux=True)
>>> gradient = vjp_fn(v)
>>> print(outputs)
[[ 3.  6.]
 [ 9. 12.]]
>>> print(aux)
[[ 1.  8.]
 [27. 64.]]
>>> print(gradient)
(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 2.00000000e+00,  2.00000000e+00],
 [ 2.00000000e+00,  2.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
[[ 1.00000000e+00,  1.00000000e+00],
 [ 1.00000000e+00,  1.00000000e+00]]))
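
The printed values above can be checked by hand, because both example functions act elementwise: for x**3 + y the VJP with respect to x is v * 3 * x**2 and with respect to y is v, while for 2 * x + y it is 2 * v and v. The sketch below reproduces these numbers in plain Python with nested lists and does not call MindSpore. The helper elementwise is a hypothetical name introduced only for this check.

```python
x = [[1.0, 2.0], [3.0, 4.0]]
y = [[1.0, 2.0], [3.0, 4.0]]
v = [[1.0, 1.0], [1.0, 1.0]]


def elementwise(op, *mats):
    """Apply op entry by entry across equally shaped nested lists."""
    return [[op(*vals) for vals in zip(*rows)] for rows in zip(*mats)]


# First example: out = x**3 + y
out1 = elementwise(lambda a, b: a ** 3 + b, x, y)
grad1_x = elementwise(lambda vi, a: vi * 3 * a ** 2, v, x)  # v * d(out)/dx
grad1_y = elementwise(lambda vi: vi * 1.0, v)               # v * d(out)/dy

# Second example: out = 2 * x + y, with y**3 as an auxiliary output
out2 = elementwise(lambda a, b: 2 * a + b, x, y)
aux2 = elementwise(lambda b: b ** 3, y)       # returned as-is, no gradient
grad2_x = elementwise(lambda vi: vi * 2.0, v)
grad2_y = elementwise(lambda vi: vi * 1.0, v)

print(out1, grad1_x, grad1_y)
print(out2, aux2, grad2_x, grad2_y)
```

In the second example the gradient with respect to y is v rather than v * (1 + 3 * y**2), which reflects that the auxiliary output y ** 3 does not contribute to the gradient.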