mindspore.ops.jet

mindspore.ops.jet(fn, primals, series)

Computes higher-order derivatives of the given composite function using Taylor-mode differentiation. To obtain the first- to n-th order derivatives of the output, pass the original inputs together with the first- to n-th order derivatives of those inputs. In the usual setup, the first-order derivative of the input is set to 1 and all higher-order derivatives are set to 0. This corresponds to differentiating the input with respect to itself (dx/dx = 1, with all higher derivatives equal to 0).
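The pure-Python sketch below illustrates the idea. It propagates a primal value and its first three derivatives through y = g(x) using Faà di Bruno's formula truncated at third order. It then chains two such steps to mirror the `Net` in the Examples section (sin followed by exp). This is not MindSpore's implementation: `jet3`, `SIN`, `EXP` and `exp_of_sin` are hypothetical names, and the fixed order of three is an assumption made for brevity.

```python
import math

def jet3(derivs, primal, series):
    """Propagate a primal and its first three derivatives through y = g(x).

    derivs: (g, g', g'', g''') as callables (hypothetical layout).
    series: (x1, x2, x3), derivatives of the input x(t) w.r.t. a scalar t.
    Uses Faa di Bruno's formula truncated at third order.
    """
    g0, g1, g2, g3 = derivs
    x1, x2, x3 = series
    y0 = g0(primal)
    y1 = g1(primal) * x1
    y2 = g2(primal) * x1 ** 2 + g1(primal) * x2
    y3 = g3(primal) * x1 ** 3 + 3 * g2(primal) * x1 * x2 + g1(primal) * x3
    return y0, (y1, y2, y3)

# sin and its first three derivatives
SIN = (math.sin, math.cos, lambda v: -math.sin(v), lambda v: -math.cos(v))
# exp is its own derivative at every order
EXP = (math.exp, math.exp, math.exp, math.exp)

def exp_of_sin(x, series):
    # Mirrors Net.construct in the example: out1 = sin(x), out2 = exp(out1)
    p1, s1 = jet3(SIN, x, series)
    return jet3(EXP, p1, s1)

print(exp_of_sin(1.0, (1.0, 0.0, 0.0)))
```

With the recommended series (1, 0, 0), the call above yields approximately 2.319777 for the primal and (1.2533808, -1.2748207, -4.0515366) for the derivatives. These values match the first element of the example output to within float32 rounding. Scaling the first-order input derivative to 2 doubles the first-order output derivative, as the chain rule requires.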

Note

If primals is a Tensor of int type, it is converted to a Tensor of float type.

Parameters
  • fn (Union[Cell, function]) – The function or Cell to differentiate; jet applies TaylorOperation to it.

  • primals (Union[Tensor, tuple[Tensor]]) – The inputs to fn.

  • series (Union[Tensor, tuple[Tensor]]) – The first- to n-th order derivatives of the input. Along the zeroth dimension, index i holds the (i+1)-th order derivative. The length of that dimension therefore sets the highest order n of the output derivatives that are computed.
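The recommended series layout can be built for any input shape: ones at order 1 and zeros at every higher order. The helper below is a hypothetical, dependency-free sketch that returns nested lists of shape (n, *shape). It assumes a single Tensor input rather than a tuple of Tensors.

```python
def make_series(shape, order):
    """Hypothetical helper: nested lists of shape (order, *shape).

    Index 0 (first-order derivative) is filled with 1.0;
    indices 1..order-1 (higher-order derivatives) are filled with 0.0.
    """
    def filled(dims, value):
        # Recursively build a nested list with the given dimensions.
        if not dims:
            return value
        return [filled(dims[1:], value) for _ in range(dims[0])]

    return [filled(list(shape), 1.0 if k == 0 else 0.0) for k in range(order)]

print(make_series((2, 2), 3))
```

For the 2x2 input in the Examples section, `make_series((2, 2), 3)` produces the same values as the literal series used there. The result can then be wrapped with `mindspore.tensor(..., mindspore.float32)`.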

Returns

Tuple(out_primals, out_series)

  • out_primals (Union[Tensor, list[Tensor]]) - The output of fn(primals).

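  • out_series (Union[Tensor, list[Tensor]]) - The first- to n-th order derivatives of the output with respect to the inputs. Index i along the zeroth dimension holds the (i+1)-th order derivative, matching the layout of series.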

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> from mindspore import nn
>>> mindspore.set_context(mode=mindspore.GRAPH_MODE)
>>> class Net(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.sin = mindspore.ops.Sin()
...         self.exp = mindspore.ops.Exp()
...     def construct(self, x):
...         out1 = self.sin(x)
...         out2 = self.exp(out1)
...         return out2
>>> primals = mindspore.tensor([[1, 2], [3, 4]], mindspore.float32)
>>> series = mindspore.tensor([[[1, 1], [1, 1]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]], mindspore.float32)
>>> net = Net()
>>> out_primals, out_series = mindspore.ops.jet(net, primals, series)
>>> print(out_primals, out_series)
[[2.319777  2.4825778]
 [1.1515628 0.4691642]] [[[ 1.2533808  -1.0331168 ]
  [-1.1400385  -0.3066662 ]]
 [[-1.2748207  -1.8274734 ]
  [ 0.966121    0.55551505]]
 [[-4.0515366   3.6724353 ]
  [ 0.5053504  -0.52061415]]]
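The Net in this example is element-wise, so each entry of out_series can be checked against the ordinary derivatives of f(x) = exp(sin x) at the corresponding input. The sketch below compares the printed values with hand-derived closed forms. The expected numbers are copied from the output above, and the 1e-4 tolerance is an assumption chosen to absorb float32 rounding.

```python
import math

# Values printed by the example above (float32, so compare loosely).
PRIMALS = [[2.319777, 2.4825778], [1.1515628, 0.4691642]]
SERIES = [
    [[1.2533808, -1.0331168], [-1.1400385, -0.3066662]],
    [[-1.2748207, -1.8274734], [0.966121, 0.55551505]],
    [[-4.0515366, 3.6724353], [0.5053504, -0.52061415]],
]
INPUTS = [[1.0, 2.0], [3.0, 4.0]]

def closed_form(x):
    """f(x) = exp(sin x) and its first three derivatives, derived by hand."""
    s, c = math.sin(x), math.cos(x)
    e = math.exp(s)
    d1 = c * e                        # f'   = cos x * e^(sin x)
    d2 = (c * c - s) * e              # f''  = (cos^2 x - sin x) * e^(sin x)
    d3 = c * e * (c * c - 3 * s - 1)  # f''' = cos x * e^(sin x) * (cos^2 x - 3 sin x - 1)
    return e, (d1, d2, d3)

def max_abs_error():
    """Largest absolute gap between closed forms and the printed output."""
    worst = 0.0
    for i in range(2):
        for j in range(2):
            value, derivs = closed_form(INPUTS[i][j])
            worst = max(worst, abs(value - PRIMALS[i][j]))
            for k in range(3):
                # SERIES[k] holds the (k+1)-th order derivative
                worst = max(worst, abs(derivs[k] - SERIES[k][i][j]))
    return worst

print(max_abs_error() < 1e-4)  # True
```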