mindspore.ops.jet

mindspore.ops.jet(fn, primals, series)

This function computes higher-order derivatives of a given composite function by Taylor-mode differentiation. To obtain the first- to n-th-order derivatives, the original inputs must be provided together with the first- to n-th-order derivatives of those inputs. It is generally recommended to set the first-order derivative of each input to 1 and all higher-order derivatives to 0. This corresponds to differentiating an input with respect to itself.
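The Taylor-mode idea can be sketched without MindSpore. The plain-Python code below is not MindSpore's implementation. It is a minimal sketch that assumes scalar inputs and only the two primitives used in the example (sin and exp). It pushes a primal and its 1st to n-th derivatives through each primitive using the standard Taylor-coefficient recurrences. The helper names (`derivs_to_coeffs`, `coeffs_to_derivs`, `jet_sin_cos`, `jet_exp`, `jet_exp_sin`) are hypothetical.

```python
import math

# Illustrative Taylor-mode sketch for scalar inputs (not MindSpore's implementation).
# A jet is stored as normalized Taylor coefficients [c0, c1, ..., cn],
# where c_k = (k-th derivative) / k!.


def derivs_to_coeffs(primal, derivs):
    """Pack a primal value and its 1st..n-th derivatives into Taylor coefficients."""
    return [primal] + [d / math.factorial(k) for k, d in enumerate(derivs, start=1)]


def coeffs_to_derivs(coeffs):
    """Unpack Taylor coefficients into (primal, [1st..n-th derivatives])."""
    return coeffs[0], [coeffs[k] * math.factorial(k) for k in range(1, len(coeffs))]


def jet_sin_cos(u):
    """Propagate a jet through sin and cos together (each needs the other's coefficients)."""
    n = len(u)
    s = [math.sin(u[0])] + [0.0] * (n - 1)
    c = [math.cos(u[0])] + [0.0] * (n - 1)
    for k in range(1, n):
        # From s' = c * u' and c' = -s * u', matched coefficient by coefficient.
        s[k] = sum(j * u[j] * c[k - j] for j in range(1, k + 1)) / k
        c[k] = -sum(j * u[j] * s[k - j] for j in range(1, k + 1)) / k
    return s, c


def jet_exp(u):
    """Propagate a jet through exp, using y' = y * u'."""
    n = len(u)
    y = [math.exp(u[0])] + [0.0] * (n - 1)
    for k in range(1, n):
        y[k] = sum(j * u[j] * y[k - j] for j in range(1, k + 1)) / k
    return y


def jet_exp_sin(primal, series):
    """Hypothetical scalar analogue of jet for f(x) = exp(sin(x))."""
    u = derivs_to_coeffs(primal, series)
    s, _ = jet_sin_cos(u)
    return coeffs_to_derivs(jet_exp(s))


# Recommended setting: first-order derivative 1, higher orders 0.
out_primal, out_series = jet_exp_sin(1.0, [1.0, 0.0, 0.0])
print(out_primal, out_series)
```

With the recommended series (1, 0, 0), the three results are the first three derivatives of exp(sin(x)) at x = 1. They agree with the first element of each slice in the MindSpore example output.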

Note

If primals is a Tensor of int type, it is converted to a Tensor of float type.

Parameters
  • fn (Union[Cell, function]) – The function or Cell to differentiate with the Taylor operation (TaylorOperation).

  • primals (Union[Tensor, tuple[Tensor]]) – The inputs to fn.

  • series (Union[Tensor, tuple[Tensor]]) – The 1st- to n-th-order derivatives of the inputs. Index i of the zeroth dimension holds the (i+1)-th-order derivative, so the size of that dimension determines the highest order of output derivative that is computed.
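In the example, primals has shape (2, 2) and series has shape (3, 2, 2): three derivative orders stacked along the zeroth dimension. The sketch below builds that layout as nested Python lists for the recommended setting, with the first-order slice set to all ones and the higher-order slices set to all zeros. It assumes a single input tensor. `make_series` is a hypothetical helper, not part of MindSpore.

```python
def make_series(shape, order):
    """Hypothetical helper: nested lists of shape (order, *shape) with the
    1st-order slice set to 1.0 and all higher-order slices set to 0.0."""

    def filled(dims, value):
        # Recursively build a nested list with the given dimensions.
        if not dims:
            return value
        return [filled(dims[1:], value) for _ in range(dims[0])]

    return [filled(tuple(shape), 1.0 if k == 0 else 0.0) for k in range(order)]


series = make_series((2, 2), order=3)
print(series)
```

The result matches the series used in the example. It can be wrapped with mindspore.tensor(series, mindspore.float32) in the same way.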

Returns

Tuple(out_primals, out_series)

  • out_primals (Union[Tensor, list[Tensor]]) - The output of fn(primals).

  • out_series (Union[Tensor, list[Tensor]]) - The 1st- to n-th-order derivatives of the output with respect to the inputs. As in series, index i of the zeroth dimension holds the (i+1)-th-order derivative.
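When series is not the recommended (1, 0, ...) choice, the output orders follow the chain rule for higher derivatives (Faà di Bruno's formula). This assumes, as the parameter description states, that series holds plain derivatives rather than Taylor coefficients divided by factorials. Under that assumption, for an element-wise f with primal x0 and input series (x1, x2, x3), the first three output orders would be:

```latex
\begin{aligned}
y_1 &= f'(x_0)\,x_1 \\
y_2 &= f''(x_0)\,x_1^2 + f'(x_0)\,x_2 \\
y_3 &= f'''(x_0)\,x_1^3 + 3\,f''(x_0)\,x_1 x_2 + f'(x_0)\,x_3
\end{aligned}
```

Setting x1 = 1 and x2 = x3 = 0 reduces this to y_k = f^(k)(x0). This is why the recommended setting yields ordinary derivatives.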

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> from mindspore import nn
>>> mindspore.set_context(mode=mindspore.GRAPH_MODE)
>>> class Net(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.sin = mindspore.ops.Sin()
...         self.exp = mindspore.ops.Exp()
...     def construct(self, x):
...         out1 = self.sin(x)
...         out2 = self.exp(out1)
...         return out2
>>> primals = mindspore.tensor([[1, 2], [3, 4]], mindspore.float32)
>>> series = mindspore.tensor([[[1, 1], [1, 1]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]], mindspore.float32)
>>> net = Net()
>>> out_primals, out_series = mindspore.ops.jet(net, primals, series)
>>> print(out_primals, out_series)
[[2.319777  2.4825778]
 [1.1515628 0.4691642]] [[[ 1.2533808  -1.0331168 ]
  [-1.1400385  -0.3066662 ]]
 [[-1.2748207  -1.8274734 ]
  [ 0.966121    0.55551505]]
 [[-4.0515366   3.6724353 ]
  [ 0.5053504  -0.52061415]]]
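The layout of the printed output can be checked against the closed-form derivatives of exp(sin(x)). This plain-Python sketch does not call MindSpore, and `exp_sin_derivs` is a hypothetical helper. Net applies sin and exp element-wise, and every first-order entry of series is 1. As a result, out_series[k] should equal the (k+1)-th derivative evaluated at each element of primals.

```python
import math


def exp_sin_derivs(x):
    """Hypothetical helper: exp(sin(x)) and its first three derivatives in closed form."""
    s, c = math.sin(x), math.cos(x)
    f = math.exp(s)
    d1 = c * f                        # f'(x)
    d2 = (c * c - s) * f              # f''(x)
    d3 = c * (c * c - 3 * s - 1) * f  # f'''(x)
    return f, [d1, d2, d3]


primals = [[1.0, 2.0], [3.0, 4.0]]
out_primals = [[exp_sin_derivs(x)[0] for x in row] for row in primals]
# out_series[k][i][j] holds the (k+1)-th derivative at primals[i][j].
out_series = [[[exp_sin_derivs(x)[1][k] for x in row] for row in primals] for k in range(3)]
print(out_primals, out_series)
```

Up to float32 rounding, these values match the printed out_primals and out_series.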