$V(S_{t}) = V(S_{t}) + \alpha (R_{t+1} + \gamma V(S_{t+1}) - V(S_{t}))$
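This is the tabular TD(0) rule: nudge $V(S_{t})$ toward the bootstrapped target $R_{t+1} + \gamma V(S_{t+1})$ by step size $\alpha$. A minimal plain-Python sketch (the state names, value table, and constants are illustrative assumptions, not part of the MindSpore example below):

```python
# Tabular TD(0) update sketch; states "A"/"B", alpha, gamma are made-up values.
def td0_update(V, s_t, r_tp1, s_tp1, alpha=0.1, gamma=0.9):
    td_error = r_tp1 + gamma * V[s_tp1] - V[s_t]  # bootstrapped target minus estimate
    V[s_t] += alpha * td_error
    return V

V = {"A": 0.0, "B": 1.0}
td0_update(V, "A", 2.0, "B")  # td_error = 2 + 0.9*1.0 - 0.0 = 2.9
print(V["A"])                 # 0.0 + 0.1 * 2.9 ≈ 0.29
```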

[1]:

from mindspore import ops, Tensor, vmap, jit, grad

value_fn = lambda theta, state: ops.tensor_dot(theta, state, axes=1)
theta = Tensor([0.2, -0.2, 0.1])


[2]:

s_t = Tensor([2., 1., -2.])
r_tp1 = Tensor(2.)
s_tp1 = Tensor([1., 2., 0.])
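
With these numbers the linear value function evaluates to $v_{\theta}(s_{t}) = 0.2 \cdot 2 - 0.2 \cdot 1 + 0.1 \cdot (-2) = 0$ and $v_{\theta}(s_{t+1}) = 0.2 \cdot 1 - 0.2 \cdot 2 = -0.2$. A quick NumPy cross-check of these dot products (independent of MindSpore):

```python
import numpy as np

theta = np.array([0.2, -0.2, 0.1])
s_t = np.array([2., 1., -2.])
s_tp1 = np.array([1., 2., 0.])

print(theta @ s_t)    # v(s_t)   = 0.0
print(theta @ s_tp1)  # v(s_tp1) = -0.2
```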


For a parameterized value function $v_{\theta}$ (taking $\gamma = 1$ for simplicity), the corresponding parameter update is

$\Delta\theta=(r_{t+1} + v_{\theta}(s_{t+1}) - v_{\theta}(s_{t}))\nabla v_{\theta}(s_{t})$

which is, up to a constant factor, the gradient of the squared TD error with the bootstrapped target held constant:

$L(\theta) = [r_{t+1} + v_{\theta}(s_{t+1}) - v_{\theta}(s_{t})]^{2}$

[3]:

def td_loss(theta, s_t, r_tp1, s_tp1):
    v_t = value_fn(theta, s_t)
    # Freeze the bootstrapped target so the gradient flows only through v_t.
    target = r_tp1 + value_fn(theta, s_tp1)
    return (ops.stop_gradient(target) - v_t) ** 2


Passing td_loss into grad computes the gradient of td_loss with respect to theta, which is exactly the update to theta.

[4]:

td_update = grad(td_loss)
delta_theta = td_update(theta, s_t, r_tp1, s_tp1)
print(delta_theta)

[-7.2 -3.6  7.2]
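
Because stop_gradient freezes the target, the gradient being computed is $\nabla_{\theta} L = -2\,(r_{t+1} + v_{\theta}(s_{t+1}) - v_{\theta}(s_{t}))\,s_{t}$ for this linear value function. A plain NumPy cross-check of that closed form (independent of MindSpore):

```python
import numpy as np

theta = np.array([0.2, -0.2, 0.1])
s_t = np.array([2., 1., -2.])
r_tp1 = 2.0
s_tp1 = np.array([1., 2., 0.])

v_t = theta @ s_t                # 0.0
target = r_tp1 + theta @ s_tp1   # 1.8, held constant by stop_gradient
delta_theta = -2 * (target - v_t) * s_t
print(delta_theta)               # [-7.2 -3.6  7.2]
```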


td_update computes the gradient of td_loss with respect to the parameter $\theta$ from a single sample. We can vectorize the function with vmap, which adds a batch dimension to every input and output: given a batch of inputs, it produces a batch of outputs, where each element of the output batch corresponds to the element at the same position in the input batch.
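
Conceptually, vmap(f) behaves as if f were applied to each slice along the batch axis and the results stacked; MindSpore vectorizes this instead of looping, but the semantics can be sketched in plain NumPy (naive_vmap and square are illustrative names, not MindSpore APIs):

```python
import numpy as np

def naive_vmap(f, *batched_inputs):
    # Apply f per sample along axis 0 and stack the results --
    # the explicit loop that vmap replaces with vectorized execution.
    return np.stack([f(*xs) for xs in zip(*batched_inputs)])

square = lambda x: x * x
batch = np.array([1., 2., 3.])
print(naive_vmap(square, batch))  # [1. 4. 9.]
```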

[5]:

batched_s_t = ops.stack([s_t, s_t])
batched_r_tp1 = ops.stack([r_tp1, r_tp1])
batched_s_tp1 = ops.stack([s_tp1, s_tp1])
batched_theta = ops.stack([theta, theta])

per_sample_grads = vmap(td_update)
delta_theta = per_sample_grads(batched_theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)

[[-7.2 -3.6  7.2]
 [-7.2 -3.6  7.2]]


Because theta is shared across samples, it does not actually need a batch dimension: setting in_axes=(None, 0, 0, 0) tells vmap to keep theta fixed while mapping over axis 0 of the remaining inputs.

[6]:

inefficient_per_sample_grads = vmap(td_update, in_axes=(None, 0, 0, 0))
delta_theta = inefficient_per_sample_grads(theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)

[[-7.2 -3.6  7.2]
 [-7.2 -3.6  7.2]]


This version still dispatches the vectorized computation op by op in eager mode; wrapping it in jit compiles it into a single graph, giving the efficient per-sample gradient function.

[7]:

efficient_per_sample_grads = jit(inefficient_per_sample_grads)
delta_theta = efficient_per_sample_grads(theta, batched_s_t, batched_r_tp1, batched_s_tp1)
print(delta_theta)

[[-7.2 -3.6  7.2]
 [-7.2 -3.6  7.2]]