# Gradient Derivation

[View Source on Gitee](https://gitee.com/mindspore/docs/blob/r2.5.0/docs/mindspore/source_en/migration_guide/model_development/gradient.md)

## Automatic Differentiation

Both MindSpore and PyTorch provide automatic differentiation: once the forward network is defined, backward propagation and gradient updates can be performed through simple interface calls. Note, however, that MindSpore and PyTorch use different logic to build the backward graph, and this difference is also reflected in their API design.
**PyTorch Automatic Differentiation**

```python
# torch.autograd:
# Gradients are accumulated across backward calls, so they must be
# cleared (e.g. with optimizer.zero_grad()) after each update.
import torch
from torch.autograd import Variable

x = Variable(torch.ones(2, 2), requires_grad=True)
x = x * 2
y = x - 1
y.backward(x)
```

**MindSpore Automatic Differentiation**

```python
# ms.grad:
# Takes the forward network as input and returns a gradient function.
import mindspore as ms
from mindspore import nn

class GradNetWrtX(nn.Cell):
    def __init__(self, net):
        super(GradNetWrtX, self).__init__()
        self.net = net

    def construct(self, x, y):
        gradient_function = ms.grad(self.net)
        return gradient_function(x, y)
```
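`GradNetWrtX` above only wraps an existing forward network, so it is not runnable on its own. Below is a minimal sketch of how it might be exercised; the `Net` definition and the input shapes are assumptions made for illustration, not part of the original example.

```python
import numpy as np
import mindspore as ms
from mindspore import nn, ops

class Net(nn.Cell):
    # Hypothetical forward network used only to exercise GradNetWrtX.
    def __init__(self):
        super(Net, self).__init__()
        self.matmul = ops.MatMul()

    def construct(self, x, y):
        return self.matmul(x, y)

x = ms.Tensor(np.ones((2, 3), np.float32))
y = ms.Tensor(np.ones((3, 4), np.float32))

# ms.grad differentiates with respect to the first input by default,
# so the result has the same shape as x.
grad_net = GradNetWrtX(Net())
print(grad_net(x, y))
```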
**PyTorch**

```python
# Before backward() is called, x.grad and y.grad are None.
# After backward(), x.grad and y.grad hold the computed derivatives.
import torch

print("=== tensor.backward ===")
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
z = x**2 + y
print("x.grad before backward", x.grad)
print("y.grad before backward", y.grad)
z.backward()
print("z", z)
print("x.grad", x.grad)
print("y.grad", y.grad)

print("=== torch.autograd.backward ===")
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
z = x**2 + y
torch.autograd.backward(z)
print("z", z)
print("x.grad", x.grad)
print("y.grad", y.grad)
```

Outputs:

```text
=== tensor.backward ===
x.grad before backward None
y.grad before backward None
z tensor(3., grad_fn=<AddBackward0>)
x.grad tensor(2.)
y.grad tensor(1.)
=== torch.autograd.backward ===
z tensor(3., grad_fn=<AddBackward0>)
x.grad tensor(2.)
y.grad tensor(1.)
```

**MindSpore**

```python
import mindspore

print("=== mindspore.grad ===")
x = mindspore.Tensor(1.0)
y = mindspore.Tensor(2.0)

def net(x, y):
    return x**2 + y

out = mindspore.grad(net, grad_position=0)(x, y)
print("out", out)
out1 = mindspore.grad(net, grad_position=1)(x, y)
print("out1", out1)
```

Outputs:

```text
=== mindspore.grad ===
out 2.0
out1 1.0
```
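In the MindSpore code above, each gradient is obtained through a separate `mindspore.grad` call. `grad_position` also accepts a tuple of input positions, so both gradients can be obtained in a single call; the following is a small sketch of that usage.

```python
import mindspore

x = mindspore.Tensor(1.0)
y = mindspore.Tensor(2.0)

def net(x, y):
    return x**2 + y

# Passing a tuple of positions returns the gradients with respect to
# both inputs at once: (dz/dx, dz/dy) = (2.0, 1.0).
grads = mindspore.grad(net, grad_position=(0, 1))(x, y)
print("grads", grads)
```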
**PyTorch**

```python
# torch.autograd.backward does not support multiple outputs.
import torch

print("=== torch.autograd.backward does not support multiple outputs ===")
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
z = x**2 + y
torch.autograd.backward(z)
print("z", z)
print("x.grad", x.grad)
print("y.grad", y.grad)
```

Outputs:

```text
=== torch.autograd.backward does not support multiple outputs ===
z tensor(3., grad_fn=<AddBackward0>)
x.grad tensor(2.)
y.grad tensor(1.)
```

**MindSpore**

```python
# mindspore.grad supports networks with multiple outputs.
import mindspore

print("=== mindspore.grad multiple outputs ===")
x = mindspore.Tensor(1.0)
y = mindspore.Tensor(2.0)

def net(x, y):
    return x**2 + y, x

out = mindspore.grad(net, grad_position=0)(x, y)
print("out", out)
out1 = mindspore.grad(net, grad_position=1)(x, y)
print("out1", out1)
```

Outputs:

```text
=== mindspore.grad multiple outputs ===
out 3.0
out1 1.0
```
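When only part of a multi-output network should drive the gradient (for example, a loss returned together with auxiliary values such as logits), `mindspore.grad` provides a `has_aux` parameter. The sketch below reuses the multi-output `net` above; the exact structure of the returned tuple may differ slightly between MindSpore versions, so treat it as an illustration rather than a definitive reference.

```python
import mindspore

x = mindspore.Tensor(1.0)
y = mindspore.Tensor(2.0)

def net(x, y):
    # Only the first output drives the gradient when has_aux=True;
    # the second output is passed through as an auxiliary value.
    return x**2 + y, x

grad_fn = mindspore.grad(net, grad_position=0, has_aux=True)
out = grad_fn(x, y)
# Expect the gradient with respect to x to be 2.0 here (not 3.0),
# because the auxiliary output x no longer contributes to it.
print("out", out)
```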