# Computing Jacobians and Hessians with Function Transforms

## The Jacobian Matrix

$R^{n} \longrightarrow R^{m}$

$\textbf{x} \longmapsto \textbf{f}(\textbf{x})$

$\nabla: F_{n}^{1} \longrightarrow F_{n}^{n}$

$\partial: F_{n}^{m} \longrightarrow F_{n}^{m \times n}$

$\textbf{f} \longmapsto \partial \textbf{f} = (\frac{\partial \textbf{f}}{\partial x_{1}}, \frac{\partial \textbf{f}}{\partial x_{2}}, \dots, \frac{\partial \textbf{f}}{\partial x_{n}})$

$\begin{split}J_{f} = \begin{bmatrix} \frac{\partial f_{1}}{\partial x_{1}} &\frac{\partial f_{1}}{\partial x_{2}} &\dots &\frac{\partial f_{1}}{\partial x_{n}} \\ \frac{\partial f_{2}}{\partial x_{1}} &\frac{\partial f_{2}}{\partial x_{2}} &\dots &\frac{\partial f_{2}}{\partial x_{n}} \\ \vdots &\vdots &\ddots &\vdots \\ \frac{\partial f_{m}}{\partial x_{1}} &\frac{\partial f_{m}}{\partial x_{2}} &\dots &\frac{\partial f_{m}}{\partial x_{n}} \end{bmatrix}\end{split}$
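The definition above can be checked numerically: a finite-difference sketch (plain NumPy, using a hypothetical function `f`; not part of the MindSpore example) builds the matrix one column per input coordinate, matching the $\frac{\partial \textbf{f}}{\partial x_{j}}$ columns in the formula.

```python
import numpy as np

# Hypothetical example: f : R^3 -> R^2, f(x) = (x1*x2, x2 + x3).
def f(x):
    return np.array([x[0] * x[1], x[1] + x[2]])

# Analytic Jacobian: row i is the gradient of f_i.
def jacobian_analytic(x):
    return np.array([
        [x[1], x[0], 0.0],   # d f1 / dx
        [0.0,  1.0,  1.0],   # d f2 / dx
    ])

# Central-difference approximation, one column per input coordinate.
def jacobian_fd(f, x, eps=1e-6):
    cols = []
    for j in range(x.size):
        e = np.zeros_like(x)
        e[j] = eps
        cols.append((f(x + e) - f(x - e)) / (2 * eps))
    return np.stack(cols, axis=1)

x = np.array([1.0, 2.0, 3.0])
assert np.allclose(jacobian_fd(f, x), jacobian_analytic(x))
```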

## Computing the Jacobian Matrix

[1]:

import time
import mindspore
from mindspore import ops
from mindspore import jacrev, jacfwd, vmap, vjp, jvp, grad
import numpy as np

mindspore.set_seed(1)

def forecast(weight, bias, x):
    return ops.dense(x, weight, bias).tanh()


[2]:

D = 16
weight = ops.randn(D, D)
bias = ops.randn(D)
x = ops.randn(D)


[3]:

def partial_forecast(x):
    return ops.dense(x, weight, bias).tanh()

_, vjp_fn = vjp(partial_forecast, x)

def compute_jac_matrix(unit_vectors):
    jacobian_rows = [vjp_fn(vec)[0] for vec in unit_vectors]
    return ops.stack(jacobian_rows)

unit_vectors = ops.eye(D)
jacobian = compute_jac_matrix(unit_vectors)
print(jacobian.shape)
print(jacobian[0])

(16, 16)
[-3.2045446e-05 -1.3530695e-05  1.8671712e-05 -9.6547810e-05
5.9755850e-05 -5.1343523e-05  1.3528993e-05 -4.6988782e-05
-4.5517798e-05 -6.1188715e-05 -1.6264191e-04  5.5033437e-05
-4.3497541e-05  2.2357668e-05 -1.3188722e-04 -3.0677278e-05]


In compute_jac_matrix, the Jacobian matrix is computed row by row in a for loop, which is not efficient. MindSpore provides jacrev to compute the Jacobian matrix; its implementation uses vmap, which eliminates the for loop in compute_jac_matrix and vectorizes the whole computation. The grad_position parameter of jacrev specifies the argument with respect to which the Jacobian of the output is computed.

[4]:

from mindspore import jacrev
jacrev_jacobian = jacrev(forecast, grad_position=2)(weight, bias, x)
assert np.allclose(jacrev_jacobian.asnumpy(), jacobian.asnumpy())


[5]:

def perf_compution(func, run_times, *args, **kwargs):
    start_time = time.perf_counter()
    for _ in range(run_times):
        func(*args, **kwargs)
    end_time = time.perf_counter()
    cost_time = (end_time - start_time) * 1000000
    return cost_time

run_times = 500
jac_fn = jacrev(forecast, grad_position=2)
compute_jac_matrix_cost_time = perf_compution(compute_jac_matrix, run_times, unit_vectors)
jacrev_cost_time = perf_compution(jac_fn, run_times, weight, bias, x)
print(f"compute_jac_matrix run {run_times} times, cost time {compute_jac_matrix_cost_time} microseconds.")
print(f"jacrev run {run_times} times, cost time {jacrev_cost_time} microseconds.")

compute_jac_matrix run 500 times, cost time 12942823.04868102 microseconds.
jacrev run 500 times, cost time 909309.7001314163 microseconds.


[6]:

def perf_cmp(first, first_descriptor, second, second_descriptor):
    faster = second
    slower = first
    gain = (slower - faster) / slower
    if gain < 0:
        gain *= -1
    final_gain = gain * 100
    print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor}. ")

perf_cmp(compute_jac_matrix_cost_time, "for loop", jacrev_cost_time, "jacrev")

 Performance delta: 92.9744 percent improvement with jacrev.


[7]:

jacrev_weight, jacrev_bias = jacrev(forecast, grad_position=(0, 1))(weight, bias, x)
print(jacrev_weight.shape)
print(jacrev_bias.shape)

(16, 16, 16)
(16, 16)


## Reverse-mode vs. Forward-mode Jacobian Computation

MindSpore provides two APIs for computing the Jacobian matrix: jacrev and jacfwd.

• jacrev: uses reverse-mode automatic differentiation.

• jacfwd: uses forward-mode automatic differentiation.

jacfwd and jacrev can be used interchangeably, but their performance differs by scenario: forward mode builds the Jacobian one column per input, so jacfwd tends to be faster for tall Jacobians (more outputs than inputs), while reverse mode builds it one row per output, so jacrev tends to be faster for wide Jacobians (more inputs than outputs).
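The trade-off can be made concrete with a minimal NumPy sketch (not MindSpore code): for a linear map f(x) = A @ x, a JVP with basis vector e_j is exactly column j of the Jacobian and a VJP with basis vector e_i is exactly row i, so forward mode needs n passes and reverse mode needs m passes.

```python
import numpy as np

m, n = 5, 2              # tall Jacobian: more outputs (m) than inputs (n)
A = np.arange(m * n, dtype=float).reshape(m, n)

jvp = lambda v: A @ v    # one forward-mode pass: Jacobian-vector product
vjp = lambda u: A.T @ u  # one reverse-mode pass: vector-Jacobian product

# jacfwd-style: n JVPs, one column per input -> cheap here since n < m.
J_fwd = np.stack([jvp(np.eye(n)[j]) for j in range(n)], axis=1)
# jacrev-style: m VJPs, one row per output.
J_rev = np.stack([vjp(np.eye(m)[i]) for i in range(m)], axis=0)

# Both strategies recover the same Jacobian (here, A itself).
assert np.allclose(J_fwd, A) and np.allclose(J_rev, A)
```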

## The Hessian Matrix

$\partial \circ \nabla: F_{n}^{1} \longrightarrow F_{n}^{n} \longrightarrow F_{n}^{n \times n}$

$f \longmapsto \nabla f \longmapsto J_{\nabla f}$

$\begin{split}H_{f} = \begin{bmatrix} \frac{\partial (\nabla _{1}f)}{\partial x_{1}} &\frac{\partial (\nabla _{1}f)}{\partial x_{2}} &\dots &\frac{\partial (\nabla _{1}f)}{\partial x_{n}} \\ \frac{\partial (\nabla _{2}f)}{\partial x_{1}} &\frac{\partial (\nabla _{2}f)}{\partial x_{2}} &\dots &\frac{\partial (\nabla _{2}f)}{\partial x_{n}} \\ \vdots &\vdots &\ddots &\vdots \\ \frac{\partial (\nabla _{n}f)}{\partial x_{1}} &\frac{\partial (\nabla _{n}f)}{\partial x_{2}} &\dots &\frac{\partial (\nabla _{n}f)}{\partial x_{n}} \end{bmatrix} = \begin{bmatrix} \frac{\partial ^2 f}{\partial x_{1}^{2}} &\frac{\partial ^2 f}{\partial x_{2} \partial x_{1}} &\dots &\frac{\partial ^2 f}{\partial x_{n} \partial x_{1}} \\ \frac{\partial ^2 f}{\partial x_{1} \partial x_{2}} &\frac{\partial ^2 f}{\partial x_{2}^{2}} &\dots &\frac{\partial ^2 f}{\partial x_{n} \partial x_{2}} \\ \vdots &\vdots &\ddots &\vdots \\ \frac{\partial ^2 f}{\partial x_{1} \partial x_{n}} &\frac{\partial ^2 f}{\partial x_{2} \partial x_{n}} &\dots &\frac{\partial ^2 f}{\partial x_{n}^{2}} \end{bmatrix}\end{split}$
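As a numerical sanity check of this definition, the sketch below (plain NumPy, with a hypothetical scalar function; not part of the MindSpore example) approximates the Hessian as the Jacobian of a closed-form gradient and verifies the symmetry guaranteed by Schwarz's theorem.

```python
import numpy as np

# Hypothetical scalar function f(x) = x0^2 * x1 with closed-form gradient.
def grad_f(x):
    return np.array([2 * x[0] * x[1], x[0] ** 2])

# The Hessian is the Jacobian of the gradient; approximate it column by
# column with central differences.
def hessian_fd(grad, x, eps=1e-5):
    cols = []
    for j in range(x.size):
        e = np.zeros_like(x)
        e[j] = eps
        cols.append((grad(x + e) - grad(x - e)) / (2 * eps))
    return np.stack(cols, axis=1)

x = np.array([1.5, -2.0])
H = hessian_fd(grad_f, x)
H_exact = np.array([[2 * x[1], 2 * x[0]],
                    [2 * x[0], 0.0]])
assert np.allclose(H, H_exact, atol=1e-6)
assert np.allclose(H, H.T, atol=1e-6)  # symmetry (Schwarz's theorem)
```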

## Computing the Hessian Matrix

[8]:

Din = 32
Dout = 16
weight = ops.randn(Dout, Din)
bias = ops.randn(Dout)
x = ops.randn(Din)

# Compose jacfwd and jacrev in the four possible ways; all four
# compositions compute the same Hessian (definitions restored here,
# since hess1-hess4 were missing from the original cell).
hess1 = jacfwd(jacfwd(forecast, grad_position=2), grad_position=2)(weight, bias, x)
hess2 = jacfwd(jacrev(forecast, grad_position=2), grad_position=2)(weight, bias, x)
hess3 = jacrev(jacfwd(forecast, grad_position=2), grad_position=2)(weight, bias, x)
hess4 = jacrev(jacrev(forecast, grad_position=2), grad_position=2)(weight, bias, x)

np.allclose(hess1.asnumpy(), hess2.asnumpy())
np.allclose(hess2.asnumpy(), hess3.asnumpy())
np.allclose(hess3.asnumpy(), hess4.asnumpy())

[8]:

True


## Computing Batch Jacobians and Batch Hessians

[9]:

batch_size = 64
Din = 31
Dout = 33

weight = ops.randn(Dout, Din)
bias = ops.randn(Dout)
x = ops.randn(batch_size, Din)

compute_batch_jacobian = vmap(jacrev(forecast, grad_position=2), in_axes=(None, None, 0))
batch_jacobian = compute_batch_jacobian(weight, bias, x)
print(batch_jacobian.shape)

(64, 33, 31)


[10]:

hessian = jacrev(jacrev(forecast, grad_position=2), grad_position=2)
compute_batch_hessian = vmap(hessian, in_axes=(None, None, 0))
batch_hessian = compute_batch_hessian(weight, bias, x)
print(batch_hessian.shape)

(64, 33, 31, 31)


## Computing Hessian-vector Products

There are two common ways to compose automatic differentiation modes when computing a Hessian-vector product:

• Compose reverse-mode automatic differentiation with reverse-mode automatic differentiation.

• Compose reverse-mode automatic differentiation with forward-mode automatic differentiation.
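Before the MindSpore versions, a finite-difference sketch (plain NumPy, with a hypothetical helper hvp_fd) shows the quantity both compositions compute: a Hessian-vector product is a directional derivative of the gradient, so it never requires materializing the full Hessian.

```python
import numpy as np

# Central difference of the gradient along direction v approximates
# H(x) @ v without ever forming H.
def hvp_fd(grad, x, v, eps=1e-5):
    return (grad(x + eps * v) - grad(x - eps * v)) / (2 * eps)

# For f(x) = sum(sin(x)): grad f = cos(x), and the Hessian is the
# diagonal matrix diag(-sin(x)), so the exact hvp is -sin(x) * v.
grad_f = np.cos
x = np.linspace(-1.0, 1.0, 128)
v = np.ones_like(x)

hvp = hvp_fd(grad_f, x, v)
assert hvp.shape == (128,)
assert np.allclose(hvp, -np.sin(x) * v, atol=1e-6)
```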

[11]:

def hvp_revfwd(f, inputs, vector):
    # jvp returns (output, jvp_value); the second element is the hvp.
    return jvp(grad(f), inputs, vector)[1]

def f(x):
    return x.sin().sum()

inputs = ops.randn(128)
vector = ops.randn(128)

result_hvp_revfwd = hvp_revfwd(f, inputs, vector)
print(result_hvp_revfwd.shape)

(128,)


[12]:

def hvp_revrev(f, inputs, vector):
    # Reverse over reverse: take the vjp of the gradient function.
    _, vjp_fn = vjp(grad(f), *inputs)
    return vjp_fn(*vector)

result_hvp_revrev = hvp_revrev(f, (inputs,), (vector,))
print(result_hvp_revrev[0].shape)

(128,)


[13]:

assert np.allclose(result_hvp_revfwd.asnumpy(), result_hvp_revrev[0].asnumpy())