mindspore.ops.GradOperation

class mindspore.ops.GradOperation(get_all=False, get_by_list=False, sens_param=False)

A higher-order function which is used to generate the gradient function for the input function.

The gradient function generated by the GradOperation higher-order function can be customized through its constructor arguments.

The walkthroughs below assume an input function net = Net() that takes x and y as inputs and has a parameter z (see Net in Examples).
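For reference, the example Net computes out = (x * z) @ y, where x has shape [2, 3], y has shape [3, 3] and z is a one-element Parameter treated here as a scalar. Writing S for the sensitivity (the gradient with respect to out, all ones unless sens_param=True), the chain rule gives the following gradients. This derivation is our own, worked from the Net definition in Examples, and it agrees with the printed outputs there.

```latex
\mathrm{out} = (z\,x)\,y, \qquad S = \frac{\partial L}{\partial\,\mathrm{out}} \in \mathbb{R}^{2\times 3}

\frac{\partial L}{\partial x} = z\,S\,y^{\top}, \qquad
\frac{\partial L}{\partial y} = z\,x^{\top} S, \qquad
\frac{\partial L}{\partial z} = \sum_{i,j} S_{ij}\,(x\,y)_{ij}
```

With S all ones, every row of the x-gradient equals the row sums of y, and every column of the y-gradient equals the column sums of x, which is why the example outputs contain repeated rows and columns.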

To generate a gradient function that returns the gradient with respect to the first input (see GradNetWrtX in Examples):

  1. Construct a GradOperation higher-order function with default arguments: grad_op = GradOperation().

  2. Call it with the input function as its argument to get the gradient function: gradient_function = grad_op(net).

  3. Call the gradient function with the input function’s inputs to get the gradient with respect to the first input: gradient_function(x, y).
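To make the default behaviour concrete, the sketch below computes by hand, in plain Python rather than MindSpore, the gradient that GradOperation() should return for the example Net, assuming out = matmul(x * z, y), z = 1.0 and the automatic all-ones sensitivity. The helpers matmul, transpose and grad_wrt_first_input are hypothetical illustrations, not MindSpore APIs.

```python
# Plain-Python sketch (not MindSpore): the first-input gradient that
# GradOperation() should return for the example Net, out = matmul(x * z, y).

def matmul(a, b):
    """Matrix product of two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    """Swap rows and columns of a list-of-rows matrix."""
    return [list(col) for col in zip(*m)]


def grad_wrt_first_input(x, y, z=1.0):
    """d(out)/dx = z * S @ y^T with the default all-ones sensitivity S."""
    sens = [[1.0] * len(y[0]) for _ in x]  # ones_like(out); out is len(x) x len(y[0])
    return [[z * v for v in row] for row in matmul(sens, transpose(y))]


# Same inputs as the GradNetWrtX example (z is initialised to 1.0 there).
x = [[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]]
y = [[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]]
dx = grad_wrt_first_input(x, y)
print([[round(v, 4) for v in row] for row in dx])
# [[1.41, 1.6, 6.6], [1.41, 1.6, 6.6]]
```

Each row of dx is just the row sums of y, so the values of x only determine the shape of the result, matching the repeated rows in the GradNetWrtX output.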

To generate a gradient function that returns gradients with respect to all inputs (see GradNetWrtXY in Examples):

  1. Construct a GradOperation higher-order function with get_all=True, which requests gradients with respect to all inputs (x and y in the example function Net()): grad_op = GradOperation(get_all=True).

  2. Call it with the input function as its argument to get the gradient function: gradient_function = grad_op(net).

  3. Call the gradient function with the input function’s inputs to get the gradients with respect to all inputs: gradient_function(x, y).
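The following plain-Python sketch (not MindSpore) reproduces the (dx, dy) tuple that GradOperation(get_all=True) should return for the example Net, assuming out = matmul(x * z, y), z = 1.0 and an all-ones sensitivity. The helper names are hypothetical.

```python
# Plain-Python sketch (not MindSpore): the (dx, dy) tuple that
# GradOperation(get_all=True) should return for out = matmul(x * z, y).

def matmul(a, b):
    """Matrix product of two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    """Swap rows and columns of a list-of-rows matrix."""
    return [list(col) for col in zip(*m)]


def grad_wrt_all_inputs(x, y, z=1.0):
    """Return (dx, dy) using the default all-ones sensitivity S."""
    sens = [[1.0] * len(y[0]) for _ in x]                              # ones_like(out)
    xz = [[z * v for v in row] for row in x]                           # x * z
    dx = [[z * v for v in row] for row in matmul(sens, transpose(y))]  # z * S @ y^T
    dy = matmul(transpose(xz), sens)                                   # (x * z)^T @ S
    return dx, dy


# Same inputs as the GradNetWrtXY example.
x = [[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]]
y = [[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]]
dx, dy = grad_wrt_all_inputs(x, y)
print([[round(v, 4) for v in row] for row in dx])  # [[4.51, 2.7, 3.6], [4.51, 2.7, 3.6]]
print([[round(v, 4) for v in row] for row in dy])  # [[2.6, 2.6, 2.6], [1.9, 1.9, 1.9], [1.3, 1.3, 1.3]]
```

The result is one gradient per input, in input order, each with the same shape as the input it belongs to.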

To generate a gradient function that returns gradients with respect to given parameters (see GradNetWithWrtParams in Examples):

  1. Construct a GradOperation higher-order function with get_by_list=True: grad_op = GradOperation(get_by_list=True).

  2. Construct a ParameterTuple to pass, together with the input function, to the GradOperation higher-order function; it acts as a parameter filter that determines which gradients are returned: params = ParameterTuple(net.trainable_params()).

  3. Call it with the input function and params as arguments to get the gradient function: gradient_function = grad_op(net, params).

  4. Call the gradient function with the input function’s inputs to get the gradients with respect to the given parameters: gradient_function(x, y).
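As a cross-check, the plain-Python sketch below (not MindSpore) computes the one-element tuple (dz,) that GradOperation(get_by_list=True) should return for the example Net, whose only trainable Parameter is z. It assumes out = matmul(x * z, y) and an all-ones sensitivity; the helper names are hypothetical.

```python
# Plain-Python sketch (not MindSpore): the one-element tuple (dz,) that
# GradOperation(get_by_list=True) should return for the example Net,
# out = matmul(x * z, y), whose only trainable Parameter is z.

def matmul(a, b):
    """Matrix product of two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def grad_wrt_params(x, y):
    """d(out)/dz = sum_ij S_ij * (x @ y)_ij; with S all ones this is a plain sum.

    out is linear in z, so the result does not depend on z's current value.
    """
    dz = sum(v for row in matmul(x, y) for v in row)
    return (dz,)  # one entry per Parameter in the ParameterTuple


# Same inputs as the GradNetWithWrtParams example.
x = [[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]]
y = [[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]]
grads = grad_wrt_params(x, y)
print(tuple(round(g, 4) for g in grads))  # (21.536,)
```

Note that the result is a tuple even for a single parameter, which is why the example output ends with ",)".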

To generate a gradient function that returns gradients with respect to all inputs and given parameters in the format ((dx, dy), (dz,)) (see GradNetWrtInputsAndParams in Examples):

  1. Construct a GradOperation higher-order function with get_all=True and get_by_list=True: grad_op = GradOperation(get_all=True, get_by_list=True).

  2. Construct a ParameterTuple to pass along with the input function to the GradOperation higher-order function: params = ParameterTuple(net.trainable_params()).

  3. Call it with the input function and params as arguments to get the gradient function: gradient_function = grad_op(net, params).

  4. Call the gradient function with the input function’s inputs to get the gradients with respect to all inputs and the given parameters: gradient_function(x, y).
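The nested layout can be illustrated with a plain-Python sketch (not MindSpore) that builds ((dx, dy), (dz,)) by hand for the example Net, assuming out = matmul(x * z, y), z = 1.0 and an all-ones sensitivity. The helper names are hypothetical.

```python
# Plain-Python sketch (not MindSpore): the nested result ((dx, dy), (dz,))
# that GradOperation(get_all=True, get_by_list=True) should return for the
# example Net, out = matmul(x * z, y), with the default all-ones sensitivity.

def matmul(a, b):
    """Matrix product of two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    """Swap rows and columns of a list-of-rows matrix."""
    return [list(col) for col in zip(*m)]


def grad_wrt_inputs_and_params(x, y, z=1.0):
    sens = [[1.0] * len(y[0]) for _ in x]                               # ones_like(out)
    dx = [[z * v for v in row] for row in matmul(sens, transpose(y))]   # z * S @ y^T
    dy = matmul(transpose([[z * v for v in row] for row in x]), sens)   # (x * z)^T @ S
    dz = sum(s * v
             for s_row, xy_row in zip(sens, matmul(x, y))
             for s, v in zip(s_row, xy_row))                            # sum(S * (x @ y))
    return (dx, dy), (dz,)  # (input gradients), (parameter gradients)


# Same inputs as the GradNetWrtInputsAndParams example.
x = [[0.1, 0.6, 1.2], [0.5, 1.3, 0.1]]
y = [[0.12, 2.3, 1.1], [1.3, 0.2, 2.4], [0.1, 2.2, 0.3]]
(dx, dy), (dz,) = grad_wrt_inputs_and_params(x, y)
print([round(v, 4) for v in dx[0]])      # [3.52, 3.9, 2.6]
print([round(row[0], 4) for row in dy])  # [0.6, 1.9, 1.3]
print(round(dz, 4))                      # 12.902
```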

We can configure the sensitivity (gradient with respect to output) by setting sens_param to True and passing an extra sensitivity input to the gradient function. The sensitivity input should have the same shape and dtype as the input function’s output (see GradNetWrtXYWithSensParam in Examples).

  1. Construct a GradOperation higher-order function with get_all=True and sens_param=True: grad_op = GradOperation(get_all=True, sens_param=True).

  2. Define the sensitivity grad_wrt_output, which serves as the gradient with respect to the output. The example Net produces a [2, 3] output, so: grad_wrt_output = Tensor(np.ones([2, 3]).astype(np.float32)).

  3. Call it with the input function as its argument to get the gradient function: gradient_function = grad_op(net).

  4. Call the gradient function with the input function’s inputs and the sensitivity to get the gradients with respect to all inputs: gradient_function(x, y, grad_wrt_output).
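The effect of a custom sensitivity can be checked with the plain-Python sketch below (not MindSpore). It recomputes the GradNetWrtXYWithSensParam gradients by hand, assuming out = matmul(x * z, y) and z = 1.0, and it rejects a sensitivity whose shape differs from the output's, mirroring the shape requirement stated above. The helper names and the ValueError are our own illustration, not MindSpore's error behaviour.

```python
# Plain-Python sketch (not MindSpore): gradients with respect to all inputs
# when a caller-supplied sensitivity S replaces ones_like(out), as with
# GradOperation(get_all=True, sens_param=True). Example Net: out = matmul(x * z, y).

def matmul(a, b):
    """Matrix product of two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    """Swap rows and columns of a list-of-rows matrix."""
    return [list(col) for col in zip(*m)]


def grad_wrt_all_inputs_with_sens(x, y, sens, z=1.0):
    """Return (dx, dy) = (z * S @ y^T, (x * z)^T @ S) for a given sensitivity S."""
    out_shape = (len(x), len(y[0]))
    sens_shape = (len(sens), len(sens[0]))
    if sens_shape != out_shape:  # S must match the shape of out
        raise ValueError(f"sensitivity shape {sens_shape} does not match output shape {out_shape}")
    dx = [[z * v for v in row] for row in matmul(sens, transpose(y))]
    dy = matmul(transpose([[z * v for v in row] for row in x]), sens)
    return dx, dy


# Same inputs and sensitivity as the GradNetWrtXYWithSensParam example.
x = [[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]]
y = [[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]]
sens = [[0.1, 0.6, 0.2], [0.8, 1.3, 1.1]]
dx, dy = grad_wrt_all_inputs_with_sens(x, y, sens)
print([[round(v, 4) for v in row] for row in dx])
# [[2.211, 0.51, 1.49], [5.588, 2.68, 4.07]]

# A [2, 2] sensitivity does not match the [2, 3] output and is rejected here.
try:
    grad_wrt_all_inputs_with_sens(x, y, [[1.0, 1.0], [1.0, 1.0]])
except ValueError as err:
    print("rejected:", err)
```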

Parameters
  • get_all (bool) – If True, get the gradients with respect to all inputs. Default: False.

  • get_by_list (bool) – If True, get the gradients with respect to the given Parameter variables. If get_all and get_by_list are both False, get the gradient with respect to the first input. If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables at the same time, in the form ((gradients with respect to inputs), (gradients with respect to parameters)). Default: False.

  • sens_param (bool) – Whether to append the sensitivity (gradient with respect to output) as an extra input. If sens_param is False, a ‘ones_like(outputs)’ sensitivity is attached automatically. If sens_param is True, the sensitivity must be passed either as a positional argument or as a keyword argument; when it is passed as a keyword argument, the key must be sens. Default: False.

Returns

The higher-order function, which takes a function as its argument and returns the gradient function for it.
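The two-stage calling pattern (configure, then apply to a function, then call the result) can be illustrated with a toy pure-Python analogue. This is only a sketch of the calling contract: grad_operation and its step size h are hypothetical names, and the toy uses central finite differences on scalar arguments, a swapped-in numerical technique, whereas GradOperation relies on MindSpore's automatic differentiation.

```python
# Toy analogue (NOT MindSpore): a higher-order function that takes a function
# and returns a gradient function. It uses central finite differences, whereas
# GradOperation relies on MindSpore's automatic differentiation.

def grad_operation(get_all=False, h=1e-6):
    """Return a transformer mapping fn(*scalars) to its gradient function."""
    def transform(fn):
        def gradient_function(*args):
            grads = []
            for i in range(len(args)):
                plus = list(args)
                minus = list(args)
                plus[i] += h
                minus[i] -= h
                # Central difference approximates d fn / d args[i].
                grads.append((fn(*plus) - fn(*minus)) / (2 * h))
            # Mirror get_all: all input gradients as a tuple, else only the first.
            return tuple(grads) if get_all else grads[0]
        return gradient_function
    return transform


def f(x, y):
    return x * y + x  # df/dx = y + 1, df/dy = x


grad_first = grad_operation()(f)
grad_all = grad_operation(get_all=True)(f)
print(round(grad_first(2.0, 3.0), 4))                 # 4.0
print(tuple(round(g, 4) for g in grad_all(2.0, 3.0)))  # (4.0, 2.0)
```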

Raises

TypeError – If get_all, get_by_list or sens_param is not a bool.

Supported Platforms:

Ascend GPU CPU

Examples

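>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore import dtype as mstype
>>> from mindspore.ops import GradOperation
>>> from mindspore.ops import operations as P
>>> from mindspore.common import ParameterTuple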
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.matmul = P.MatMul()
...         self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
...     def construct(self, x, y):
...         x = x * self.z
...         out = self.matmul(x, y)
...         return out
...
>>> class GradNetWrtX(nn.Cell):
...     def __init__(self, net):
...         super(GradNetWrtX, self).__init__()
...         self.net = net
...         self.grad_op = GradOperation()
...     def construct(self, x, y):
...         gradient_function = self.grad_op(self.net)
...         return gradient_function(x, y)
...
>>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
>>> output = GradNetWrtX(Net())(x, y)
>>> print(output)
[[1.4100001 1.5999999 6.6      ]
 [1.4100001 1.5999999 6.6      ]]
>>>
>>> class GradNetWrtXY(nn.Cell):
...     def __init__(self, net):
...         super(GradNetWrtXY, self).__init__()
...         self.net = net
...         self.grad_op = GradOperation(get_all=True)
...     def construct(self, x, y):
...         gradient_function = self.grad_op(self.net)
...         return gradient_function(x, y)
>>>
>>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
>>> output = GradNetWrtXY(Net())(x, y)
>>> print(output)
(Tensor(shape=[2, 3], dtype=Float32, value=
[[ 4.50999975e+00,  2.70000005e+00,  3.60000014e+00],
 [ 4.50999975e+00,  2.70000005e+00,  3.60000014e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
[[ 2.59999990e+00,  2.59999990e+00,  2.59999990e+00],
 [ 1.89999998e+00,  1.89999998e+00,  1.89999998e+00],
 [ 1.30000007e+00,  1.30000007e+00,  1.30000007e+00]]))
>>>
>>> class GradNetWrtXYWithSensParam(nn.Cell):
...     def __init__(self, net):
...         super(GradNetWrtXYWithSensParam, self).__init__()
...         self.net = net
...         self.grad_op = GradOperation(get_all=True, sens_param=True)
...         self.grad_wrt_output = Tensor([[0.1, 0.6, 0.2], [0.8, 1.3, 1.1]], dtype=mstype.float32)
...     def construct(self, x, y):
...         gradient_function = self.grad_op(self.net)
...         return gradient_function(x, y, self.grad_wrt_output)
>>>
>>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
>>> output = GradNetWrtXYWithSensParam(Net())(x, y)
>>> print(output)
(Tensor(shape=[2, 3], dtype=Float32, value=
[[ 2.21099997e+00,  5.09999990e-01,  1.49000001e+00],
 [ 5.58799982e+00,  2.68000007e+00,  4.07000017e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
[[ 1.51999998e+00,  2.81999993e+00,  2.14000010e+00],
 [ 1.09999990e+00,  2.04999971e+00,  1.54999995e+00],
 [ 9.00000036e-01,  1.54999995e+00,  1.25000000e+00]]))
>>>
>>> class GradNetWithWrtParams(nn.Cell):
...     def __init__(self, net):
...         super(GradNetWithWrtParams, self).__init__()
...         self.net = net
...         self.params = ParameterTuple(net.trainable_params())
...         self.grad_op = GradOperation(get_by_list=True)
...     def construct(self, x, y):
...         gradient_function = self.grad_op(self.net, self.params)
...         return gradient_function(x, y)
>>>
>>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
>>> output = GradNetWithWrtParams(Net())(x, y)
>>> print(output)
(Tensor(shape=[1], dtype=Float32, value= [ 2.15359993e+01]),)
>>>
>>> class GradNetWrtInputsAndParams(nn.Cell):
...     def __init__(self, net):
...         super(GradNetWrtInputsAndParams, self).__init__()
...         self.net = net
...         self.params = ParameterTuple(net.trainable_params())
...         self.grad_op = GradOperation(get_all=True, get_by_list=True)
...     def construct(self, x, y):
...         gradient_function = self.grad_op(self.net, self.params)
...         return gradient_function(x, y)
>>>
>>> x = Tensor([[0.1, 0.6, 1.2], [0.5, 1.3, 0.1]], dtype=mstype.float32)
>>> y = Tensor([[0.12, 2.3, 1.1], [1.3, 0.2, 2.4], [0.1, 2.2, 0.3]], dtype=mstype.float32)
>>> output = GradNetWrtInputsAndParams(Net())(x, y)
>>> print(output)
((Tensor(shape=[2, 3], dtype=Float32, value=
[[ 3.51999998e+00,  3.90000010e+00,  2.59999990e+00],
 [ 3.51999998e+00,  3.90000010e+00,  2.59999990e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
[[ 6.00000024e-01,  6.00000024e-01,  6.00000024e-01],
 [ 1.89999998e+00,  1.89999998e+00,  1.89999998e+00],
 [ 1.30000007e+00,  1.30000007e+00,  1.30000007e+00]])), (Tensor(shape=[1], dtype=Float32, value=
 [ 1.29020004e+01]),))