# Gradient Derivation

[View Source on Gitee](https://gitee.com/mindspore/docs/blob/r2.3.0rc2/docs/mindspore/source_en/migration_guide/model_development/gradient.md)

## Automatic Differentiation

Both MindSpore and PyTorch provide automatic differentiation: once the forward network is defined, backpropagation and gradient updates can be performed by calling a simple interface. Note, however, that MindSpore and PyTorch build their backward graphs with different logic, and this difference is also reflected in the API design.
**PyTorch Automatic Differentiation**

```python
# torch.autograd:
# backward() accumulates gradients, so they need to be cleared
# (e.g. with optimizer.zero_grad()) after each update.
import torch
from torch.autograd import Variable

x = Variable(torch.ones(2, 2), requires_grad=True)
x = x * 2
y = x - 1
y.backward(x)
```
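The accumulation mentioned in the comment above is worth spelling out: each `backward()` call adds into `.grad`, so gradients are normally cleared (for example with `optimizer.zero_grad()`) before the next step. A minimal sketch of this behaviour, using an illustrative tensor `w` that is not part of the original example:

```python
import torch

w = torch.tensor(3.0, requires_grad=True)

# First backward pass: d(2 * w) / dw = 2
(2 * w).backward()
print(w.grad)  # tensor(2.)

# A second backward pass accumulates into the same .grad buffer: 2 + 2 = 4
(2 * w).backward()
print(w.grad)  # tensor(4.)

# Clearing the gradient, as optimizer.zero_grad() would do before the next update
w.grad.zero_()
print(w.grad)  # tensor(0.)
```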
**MindSpore Automatic Differentiation**

```python
# ms.grad:
# The forward graph is the input, and the backward graph is the output.
import mindspore as ms
from mindspore import nn

class GradNetWrtX(nn.Cell):
    def __init__(self, net):
        super(GradNetWrtX, self).__init__()
        self.net = net

    def construct(self, x, y):
        gradient_function = ms.grad(self.net)
        return gradient_function(x, y)
```
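As a usage sketch, the wrapper above can be applied to any forward `Cell`. The minimal `Net` below is a hypothetical example, not part of the original text:

```python
import numpy as np
import mindspore as ms
from mindspore import nn

class Net(nn.Cell):
    """Hypothetical forward network: element-wise product of the two inputs."""
    def construct(self, x, y):
        return x * y

x = ms.Tensor(np.array([[1.0, 2.0], [3.0, 4.0]]), ms.float32)
y = ms.Tensor(np.array([[5.0, 6.0], [7.0, 8.0]]), ms.float32)

# ms.grad differentiates with respect to the first input by default,
# so for z = x * y the gradient with respect to x equals y.
grad_net = GradNetWrtX(Net())
print(grad_net(x, y))
```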
**PyTorch**

```python
# Before backward() is called, x.grad and y.grad are None.
# After backward(), x.grad and y.grad hold the derivatives of z
# with respect to x and y, respectively.
import torch

print("=== tensor.backward ===")
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
z = x**2 + y
print("x.grad before backward", x.grad)
print("y.grad before backward", y.grad)
z.backward()
print("z", z)
print("x.grad", x.grad)
print("y.grad", y.grad)

print("=== torch.autograd.backward ===")
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
z = x**2 + y
torch.autograd.backward(z)
print("z", z)
print("x.grad", x.grad)
print("y.grad", y.grad)
```
**MindSpore**

```python
import mindspore

print("=== mindspore.grad ===")
x = mindspore.Tensor(1.0)
y = mindspore.Tensor(2.0)

def net(x, y):
    return x**2 + y

# grad_position selects which input the gradient is taken with respect to
out = mindspore.grad(net, grad_position=0)(x, y)
print("out", out)
out1 = mindspore.grad(net, grad_position=1)(x, y)
print("out1", out1)
```
Outputs (PyTorch):

```text
=== tensor.backward ===
x.grad before backward None
y.grad before backward None
z tensor(3., grad_fn=<AddBackward0>)
x.grad tensor(2.)
y.grad tensor(1.)
=== torch.autograd.backward ===
z tensor(3., grad_fn=<AddBackward0>)
x.grad tensor(2.)
y.grad tensor(1.)
```
Outputs (MindSpore):

```text
=== mindspore.grad ===
out 2.0
out1 1.0
```
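`grad_position` can also be given as a tuple, so the gradients with respect to several inputs are returned from a single call instead of calling `mindspore.grad` twice. A brief sketch:

```python
import mindspore

x = mindspore.Tensor(1.0)
y = mindspore.Tensor(2.0)

def net(x, y):
    return x**2 + y

# grad_position=(0, 1) returns the gradients with respect to x and y as a tuple
grads = mindspore.grad(net, grad_position=(0, 1))(x, y)
print("grads", grads)  # expected: (2.0, 1.0)
```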
**PyTorch**

```python
# Multiple outputs are not supported.
import torch

print("=== torch.autograd.backward does not support multiple outputs ===")
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
z = x**2 + y
torch.autograd.backward(z)
print("z", z)
print("x.grad", x.grad)
print("y.grad", y.grad)
```
**MindSpore**

```python
# Multiple outputs are supported.
import mindspore

print("=== mindspore.grad multiple outputs ===")
x = mindspore.Tensor(1.0)
y = mindspore.Tensor(2.0)

def net(x, y):
    return x**2 + y, x

out = mindspore.grad(net, grad_position=0)(x, y)
print("out", out)
out1 = mindspore.grad(net, grad_position=1)(x, y)
print("out1", out1)
```
Outputs (PyTorch):

```text
=== torch.autograd.backward does not support multiple outputs ===
z tensor(3., grad_fn=<AddBackward0>)
x.grad tensor(2.)
y.grad tensor(1.)
```
Outputs (MindSpore):

```text
=== mindspore.grad multiple outputs ===
out 3.0
out1 1.0
```
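With multiple outputs, `mindspore.grad` by default differentiates the sum of all outputs, which is why `out` above is 2x + 1 = 3 at x = 1.0. If only the first output should be differentiated and the others returned as auxiliary data, the `has_aux` parameter can be used; a short sketch, reusing the `net` above:

```python
import mindspore

x = mindspore.Tensor(1.0)
y = mindspore.Tensor(2.0)

def net(x, y):
    # With has_aux=True only the first output is differentiated;
    # the second output is returned as auxiliary data.
    return x**2 + y, x

grad_fn = mindspore.grad(net, grad_position=0, has_aux=True)
grad, aux = grad_fn(x, y)
print("grad", grad)  # expected: 2.0, i.e. d(x**2 + y)/dx at x = 1.0
print("aux", aux)    # the auxiliary output x
```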