mindspore.ops.value_and_grad

mindspore.ops.value_and_grad(fn, grad_position=0, weights=None, has_aux=False)

A wrapper function that generates a function computing both the forward output and the gradient of the input function.

For the gradient, three typical cases are supported:

  1. Gradient with respect to inputs. In this case, grad_position is not None and weights is None.

  2. Gradient with respect to weights. In this case, grad_position is None and weights is not None.

  3. Gradient with respect to both inputs and weights. In this case, neither grad_position nor weights is None.
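The three cases differ only in what the returned gradient contains. The sketch below is a hypothetical pure-Python mimic, not MindSpore code: `Param`, `toy_value_and_grad`, and `_central_diff` are illustrative names, gradients are approximated with central finite differences on scalar floats, `fn` is assumed to return a single float, and "weights" are `Param` objects that `fn` reads through a closure. It shows the return layout for each case and, in this toy, raises `ValueError` when the wrapper is built with both `grad_position` and `weights` set to `None`.

```python
class Param:
    """Hypothetical stand-in for a trainable Parameter holding one float."""

    def __init__(self, value):
        self.value = value


def _central_diff(f, get, put, eps=1e-6):
    """Approximate d f() / d slot, where get/put read and write one scalar slot."""
    x0 = get()
    put(x0 + eps)
    up = f()
    put(x0 - eps)
    down = f()
    put(x0)  # restore the original value
    return (up - down) / (2 * eps)


def toy_value_and_grad(fn, grad_position=0, weights=None):
    """Toy mimic of the three cases; fn must return a single float."""
    if grad_position is None and weights is None:
        raise ValueError("grad_position and weights cannot both be None")

    def wrapped(*args):
        args = list(args)
        value = fn(*args)

        def call():
            return fn(*args)  # sees the in-place edits made to args

        def input_grad(i):
            return _central_diff(call, lambda: args[i],
                                 lambda v: args.__setitem__(i, v))

        def weight_grad(p):
            return _central_diff(call, lambda: p.value,
                                 lambda v: setattr(p, "value", v))

        if grad_position is None:
            inputs_grad = None
        elif isinstance(grad_position, int):
            inputs_grad = input_grad(grad_position)  # int -> one bare gradient
        else:
            inputs_grad = tuple(input_grad(i) for i in grad_position)

        if weights is None:
            return value, inputs_grad               # case 1: inputs only
        weights_grad = tuple(weight_grad(p) for p in weights)
        if inputs_grad is None:
            return value, weights_grad              # case 2: weights only
        return value, (inputs_grad, weights_grad)   # case 3: both

    return wrapped


w, b = Param(2.0), Param(0.5)


def f(x, y):
    # Reads the "weights" through the closure, like a Cell reads its Parameters.
    return w.value * x * y + b.value


v, g1 = toy_value_and_grad(f, grad_position=1)(3.0, 4.0)                    # g1 ~ 6.0
_, g2 = toy_value_and_grad(f, grad_position=None, weights=[w, b])(3.0, 4.0)  # g2 ~ (12.0, 1.0)
_, g3 = toy_value_and_grad(f, grad_position=(0, 1), weights=[w])(3.0, 4.0)   # g3 ~ ((8.0, 6.0), (12.0,))
```

The nesting of `g3` mirrors case 3: a pair of (input gradients, weight gradients), while cases 1 and 2 return only one of the two.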

Parameters
  • fn (Union(Cell, function)) – The function or Cell to be differentiated.

  • grad_position (Union(NoneType, int, tuple[int])) – Index or indices of the inputs to differentiate with respect to. If int, the gradient with respect to a single input is computed. If tuple, the gradients with respect to the selected inputs are computed. Indices start at 0. If None, no gradient with respect to any input is computed, and weights must be provided. Default: 0.

  • weights (Union(ParameterTuple, Parameter, list(Parameter))) – The parameters of the training network for which gradients are computed. They can be obtained with weights = net.trainable_params(). Default: None.

  • has_aux (bool) – If True, only the first output of fn contributes to the gradient, while the remaining outputs are returned as-is. In this case, fn must return more than one output. This is an experimental feature and is subject to change. Default: False.
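To illustrate `has_aux`, here is a hypothetical scalar-only mimic, again using central finite differences rather than MindSpore's automatic differentiation; `toy_value_and_grad` is an illustrative name, not the library function. It is applied to a scalar version of the second example in Examples: only the first output `res` is differentiated, and the auxiliary output `z` is passed through unchanged.

```python
import math


def toy_value_and_grad(fn, grad_position=0, has_aux=False, eps=1e-6):
    """Hypothetical scalar-only mimic of has_aux; NOT MindSpore's implementation."""
    single = isinstance(grad_position, int)
    positions = (grad_position,) if single else tuple(grad_position)

    def scalar_out(*args):
        out = fn(*args)
        # has_aux=True: differentiate only the first output.
        # has_aux=False: this toy requires fn to return a single float.
        return out[0] if has_aux else out

    def wrapped(*args):
        outputs = fn(*args)  # returned unchanged, including auxiliary outputs
        grads = []
        for i in positions:
            up, down = list(args), list(args)
            up[i] += eps
            down[i] -= eps
            grads.append((scalar_out(*up) - scalar_out(*down)) / (2 * eps))
        return outputs, (grads[0] if single else tuple(grads))

    return wrapped


# Scalar version of the second example: res = x * exp(y) * z**2, aux output z.
def fn(x, y, z):
    res = x * math.exp(y) * z ** 2
    return res, z


(res, aux), (dy, dz) = toy_value_and_grad(fn, grad_position=(1, 2), has_aux=True)(3.0, 0.0, 5.0)
print(res, aux)                    # 75.0 5.0
print(round(dy, 3), round(dz, 3))  # 75.0 30.0
```

Analytically, d(res)/dy = x·exp(y)·z² = 75 and d(res)/dz = 2·x·exp(y)·z = 30 at (3, 0, 5), which matches the per-element values printed in the second example.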

Returns

Function, a gradient function that computes both the forward output and the gradient of the input function or cell. For example, if out1, out2 = fn(*args), the gradient function returns ((out1, out2), gradient). When has_aux is True, only out1 contributes to the differentiation.

Raises
  • ValueError – If both grad_position and weights are None.

  • TypeError – If the type of any argument is not one of the required types.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops, nn
>>> from mindspore.ops import value_and_grad
>>>
>>> # Cell object to be differentiated
>>> class Net(nn.Cell):
...     def construct(self, x, y, z):
...         return x * y * z
>>> x = Tensor([1, 2], mindspore.float32)
>>> y = Tensor([-2, 3], mindspore.float32)
>>> z = Tensor([0, 3], mindspore.float32)
>>> net = Net()
>>> grad_fn = value_and_grad(net, grad_position=1)
>>> output, inputs_gradient = grad_fn(x, y, z)
>>> print(output)
[-0. 18.]
>>> print(inputs_gradient)
[0. 6.]
>>>
>>> # Function object to be differentiated
>>> def fn(x, y, z):
...     res = x * ops.exp(y) * ops.pow(z, 2)
...     return res, z
>>> x = Tensor(np.array([3, 3]).astype(np.float32))
>>> y = Tensor(np.array([0, 0]).astype(np.float32))
>>> z = Tensor(np.array([5, 5]).astype(np.float32))
>>> output, inputs_gradient = value_and_grad(fn, grad_position=(1, 2), weights=None, has_aux=True)(x, y, z)
>>> print(output)
(Tensor(shape=[2], dtype=Float32, value= [ 7.50000000e+01,  7.50000000e+01]),
 Tensor(shape=[2], dtype=Float32, value= [ 5.00000000e+00,  5.00000000e+00]))
>>> print(inputs_gradient)
(Tensor(shape=[2], dtype=Float32, value= [ 7.50000000e+01,  7.50000000e+01]),
 Tensor(shape=[2], dtype=Float32, value= [ 3.00000000e+01,  3.00000000e+01]))
>>>
>>> # For a network differentiated with respect to inputs, weights, or both, there are 3 cases.
>>> net = nn.Dense(10, 1)
>>> loss_fn = nn.MSELoss()
>>> def forward(inputs, labels):
...     logits = net(inputs)
...     loss = loss_fn(logits, labels)
...     return loss, logits
>>> inputs = Tensor(np.random.randn(16, 10).astype(np.float32))
>>> labels = Tensor(np.random.randn(16, 1).astype(np.float32))
>>> weights = net.trainable_params()
>>>
>>> # Case 1: gradient with respect to inputs.
>>> grad_fn = value_and_grad(forward, grad_position=0, weights=None, has_aux=True)
>>> (loss, logits), inputs_gradient = grad_fn(inputs, labels)
>>> print(logits.shape)
(16, 1)
>>> print(inputs.shape, inputs_gradient.shape)
(16, 10) (16, 10)
>>>
>>> # Case 2: gradient with respect to weights.
>>> grad_fn = value_and_grad(forward, grad_position=None, weights=weights, has_aux=True)
>>> (loss, logits), params_gradient = grad_fn(inputs, labels)
>>> print(logits.shape)
(16, 1)
>>> print(len(weights), len(params_gradient))
2 2
>>>
>>> # Case 3: gradient with respect to inputs and weights.
>>> grad_fn = value_and_grad(forward, grad_position=0, weights=weights, has_aux=False)
>>> (loss, logits), (inputs_gradient, params_gradient) = grad_fn(inputs, labels)
>>> print(logits.shape)
(16, 1)
>>> print(inputs.shape, inputs_gradient.shape)
(16, 10) (16, 10)
>>> print(len(weights), len(params_gradient))
2 2
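The gradient shapes printed above follow from the chain rule. As a hedged, pure-Python sketch, assume a `Dense(D, 1)` layer (its weight flattened to a list of length D) followed by `MSELoss` with its default `reduction="mean"`, and differentiate only the loss, as in Cases 1 and 2 where `has_aux=True`. The input gradient then has the same shape as `inputs`, and there is one gradient per trainable parameter (weight and bias). The sizes N=4, D=3, the random data, and the helper names are hypothetical stand-ins for the (16, 10) example.

```python
import random

random.seed(0)
N, D = 4, 3  # hypothetical small sizes standing in for (16, 10)
inputs = [[random.gauss(0.0, 1.0) for _ in range(D)] for _ in range(N)]
labels = [random.gauss(0.0, 1.0) for _ in range(N)]
W = [random.gauss(0.0, 1.0) for _ in range(D)]  # weight of a Dense(D, 1), flattened
b = 0.1                                        # bias of the same Dense(D, 1)


def dense_mse(xs, W, b):
    """Dense(D, 1) followed by mean squared error (reduction="mean")."""
    logits = [sum(w * v for w, v in zip(W, row)) + b for row in xs]
    loss = sum((p - t) ** 2 for p, t in zip(logits, labels)) / N
    return loss, logits


def analytic_grads(xs, W, b):
    """Gradients of the loss only, via the chain rule."""
    _, logits = dense_mse(xs, W, b)
    r = [2.0 * (p - t) / N for p, t in zip(logits, labels)]  # dL/dlogit_n
    d_inputs = [[rn * w for w in W] for rn in r]             # shape (N, D), like inputs
    d_W = [sum(rn * row[k] for rn, row in zip(r, xs)) for k in range(D)]
    d_b = sum(r)
    return d_inputs, (d_W, d_b)  # nested like (inputs_gradient, params_gradient)


d_inputs, params_grad = analytic_grads(inputs, W, b)
print(len(d_inputs), len(d_inputs[0]))  # 4 3
print(len(params_grad))                 # 2
```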