# Gradient Accumulation

## How Gradient Accumulation Works

In standard mini-batch training, each step computes the loss on one batch and immediately updates the parameters:

$Loss(\theta)=\frac{1}{2}\left(h(x^{k})-y^{k}\right)^{2}$

$\theta_{i}=\theta_{i-1}-lr * grad_{i}$

With gradient accumulation, the gradients of $N$ consecutive micro-batches are summed first, and the parameters are updated only once using the accumulated gradient:

$accumulated=\sum_{i=0}^{N} grad_{i}$

$\theta_{i}=\theta_{i-1}-lr * accumulated$
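The update rules above can be checked numerically. The sketch below uses a toy quadratic loss $\frac{1}{2}(\theta x - y)^2$; all numbers and names are illustrative, not part of the tutorial:

```python
# Toy check: the accumulated gradient equals the gradient of the summed loss.
def grad(theta, x, y):
    # Analytic gradient: d/dtheta of 1/2 * (theta*x - y)^2
    return (theta * x - y) * x

lr = 0.1
theta = 0.0
micro_batches = [(1.0, 2.0), (2.0, 2.0), (1.0, 4.0), (3.0, 3.0)]

# Accumulate the per-micro-batch gradients, then update once.
accumulated = sum(grad(theta, x, y) for x, y in micro_batches)
theta_new = theta - lr * accumulated

# Verify against a central finite difference of the summed loss.
def total_loss(t):
    return sum(0.5 * (t * x - y) ** 2 for x, y in micro_batches)

eps = 1e-6
numeric = (total_loss(theta + eps) - total_loss(theta - eps)) / (2 * eps)
assert abs(numeric - accumulated) < 1e-4
```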

1. Learning rate: under certain conditions, a larger Batch size gives better training results, and gradient accumulation simulates a larger Batch size. With accumulation steps set to 4, the effective Batch size grows 4x. Empirically, the learning rate should be scaled up appropriately when using gradient accumulation.

2. Batch Norm: with accumulation steps set to 4, the simulated 4x Batch size does not have exactly the same data distribution as a real 4x batch. The mean and variance that Batch Norm computes on each small batch differ from those of the full batch, so some implementations replace Batch Norm with Group Norm.
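The Batch Norm caveat in point 2 can be seen numerically: averaging the per-micro-batch statistics recovers the full-batch mean but underestimates the full-batch variance. A small NumPy sketch (the data here is synthetic):

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(loc=2.0, scale=3.0, size=128)

# Statistics a true Batch size of 128 would see.
full_mean, full_var = data.mean(), data.var()

# Per-micro-batch statistics with accumulate_step = 4 (Batch size 32 each).
chunks = np.split(data, 4)
micro_means = [c.mean() for c in chunks]
micro_vars = [c.var() for c in chunks]

# The averaged micro-batch mean matches the full-batch mean exactly...
assert abs(np.mean(micro_means) - full_mean) < 1e-12
# ...but the averaged micro-batch variance is smaller than the full-batch
# variance, because each micro-batch is centered on its own mean.
assert np.mean(micro_vars) < full_var
```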

## Implementing Gradient Accumulation

[1]:

import mindspore as ms
from mindspore import Tensor, Parameter, ops

@ms.jit_class
class Accumulator():
    def __init__(self, optimizer, accumulate_step, clip_norm=1.0):
        self.optimizer = optimizer
        self.clip_norm = clip_norm
        # Buffer holding the running sum of gradients across micro-batches
        self.inner_grads = optimizer.parameters.clone(prefix="accumulate_", init='zeros')
        # All-zero tensors used to reset inner_grads after each update
        self.zeros = optimizer.parameters.clone(prefix="zeros_", init='zeros')
        self.counter = Parameter(Tensor(1, ms.int32), 'counter_')
        assert accumulate_step > 0
        self.accumulate_step = accumulate_step
        self.map = ops.HyperMap()

    def __call__(self, grads):
        # Add this step's gradients into the accumulation buffer
        self.map(ops.partial(ops.assign_add), self.inner_grads, grads)
        if self.counter % self.accumulate_step == 0:
            # Accumulation steps reached: run the optimizer update
            self.optimizer(self.inner_grads)
            # After the update, reset the accumulation buffer to zeros
            self.map(ops.partial(ops.assign), self.inner_grads, self.zeros)
        # Increment the step counter
        ops.assign_add(self.counter, Tensor(1, ms.int32))

        return True


`ms.jit_class` is MindSpore's just-in-time compilation decorator; it allows an ordinary Python class to be used inside a compilable computation graph.
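Stripped of the MindSpore machinery, the Accumulator's control flow reduces to a counter gating the optimizer. The plain-Python sketch below is illustrative only (the `MockAccumulator` class and the list-append "optimizer" are not part of any API):

```python
class MockAccumulator:
    """Plain-Python sketch of the accumulation control flow: gradients are
    summed on every call, and the optimizer fires only once every
    accumulate_step calls, after which the buffer is reset."""

    def __init__(self, optimizer, accumulate_step):
        assert accumulate_step > 0
        self.optimizer = optimizer
        self.accumulate_step = accumulate_step
        self.inner_grads = 0.0   # accumulation buffer (a scalar here)
        self.counter = 1

    def __call__(self, grads):
        self.inner_grads += grads
        if self.counter % self.accumulate_step == 0:
            # Accumulation steps reached: update, then reset the buffer
            self.optimizer(self.inner_grads)
            self.inner_grads = 0.0
        self.counter += 1
        return True

# "Optimizer" that just records the gradient it was handed.
updates = []
acc = MockAccumulator(updates.append, accumulate_step=2)
for g in [1.0, 2.0, 3.0, 4.0]:
    acc(g)
# The optimizer ran twice, each time with the sum of two micro-batch grads:
# updates == [3.0, 7.0]
```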

[2]:

from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

def datapipe(path, batch_size):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(ms.int32)

    dataset = MnistDataset(path)
    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()

Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip (10.3 MB)

file_sizes: 100%|██████████████████████████| 10.8M/10.8M [00:06<00:00, 1.67MB/s]
Extracting zip file...


[3]:

accumulate_step = 2

loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)
accumulator = Accumulator(optimizer, accumulate_step)

def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    # Divide the loss by accumulate_step so that the summed
    # micro-batch gradients match the large-batch gradient
    return loss / accumulate_step
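Dividing the loss by `accumulate_step` means the summed micro-batch gradients reproduce the *average* (large-batch) gradient rather than its sum. A plain-Python check with a toy quadratic loss (all names and numbers here are illustrative):

```python
accumulate_step = 2

def grad_of_scaled_loss(theta, x, y, n):
    # d/dtheta of (1/2 * (theta*x - y)^2) / n
    return (theta * x - y) * x / n

theta = 0.0
micro_batches = [(1.0, 2.0), (2.0, 6.0)]

# Sum of gradients of the scaled loss across the micro-batches...
summed = sum(grad_of_scaled_loss(theta, x, y, accumulate_step)
             for x, y in micro_batches)
# ...equals the average of the unscaled per-batch gradients.
mean_grad = sum((theta * x - y) * x for x, y in micro_batches) / accumulate_step
assert abs(summed - mean_grad) < 1e-12
```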


[4]:

grad_fn = ms.value_and_grad(forward_fn, None, model.trainable_params())

@ms.jit
def train_step(data, label):
    loss, grads = grad_fn(data, label)
    # Hand the gradients to the Accumulator; it updates the
    # parameters only every accumulate_step calls
    accumulator(grads)
    return loss


[5]:

def train_loop(model, dataset, loss_fn, optimizer):
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

[6]:

def test_loop(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


[7]:

train_dataset = datapipe('MNIST_Data/train', 32)
test_dataset = datapipe('MNIST_Data/test', 32)


[8]:

epochs = 3
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(model, train_dataset, loss_fn, optimizer)
    test_loop(model, test_dataset, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 1.150851  [  0/1875]
loss: 1.149633  [100/1875]
loss: 1.145340  [200/1875]
loss: 1.140591  [300/1875]
loss: 1.134244  [400/1875]
loss: 1.125991  [500/1875]
loss: 1.100611  [600/1875]
loss: 1.051961  [700/1875]
loss: 0.925877  [800/1875]
loss: 0.879966  [900/1875]
loss: 0.750192  [1000/1875]
loss: 0.617844  [1100/1875]
loss: 0.470084  [1200/1875]
loss: 0.560856  [1300/1875]
loss: 0.359766  [1400/1875]
loss: 0.502521  [1500/1875]
loss: 0.299145  [1600/1875]
loss: 0.383266  [1700/1875]
loss: 0.239381  [1800/1875]
Test:
Accuracy: 84.8%, Avg loss: 0.528309

Epoch 2
-------------------------------
loss: 0.390662  [  0/1875]
loss: 0.250778  [100/1875]
loss: 0.570571  [200/1875]
loss: 0.196102  [300/1875]
loss: 0.297634  [400/1875]
loss: 0.192528  [500/1875]
loss: 0.231240  [600/1875]
loss: 0.144425  [700/1875]
loss: 0.113696  [800/1875]
loss: 0.233481  [900/1875]
loss: 0.212078  [1000/1875]
loss: 0.144562  [1100/1875]
loss: 0.220822  [1200/1875]
loss: 0.197890  [1300/1875]
loss: 0.283782  [1400/1875]
loss: 0.219684  [1500/1875]
loss: 0.155505  [1600/1875]
loss: 0.255665  [1700/1875]
loss: 0.155548  [1800/1875]
Test:
Accuracy: 90.1%, Avg loss: 0.340294

Epoch 3
-------------------------------
loss: 0.176077  [  0/1875]
loss: 0.204260  [100/1875]
loss: 0.339903  [200/1875]
loss: 0.221457  [300/1875]
loss: 0.244668  [400/1875]
loss: 0.089163  [500/1875]
loss: 0.159595  [600/1875]
loss: 0.211632  [700/1875]
loss: 0.096592  [800/1875]
loss: 0.081018  [900/1875]
loss: 0.190852  [1000/1875]
loss: 0.139729  [1100/1875]
loss: 0.049344  [1200/1875]
loss: 0.122041  [1300/1875]
loss: 0.198622  [1400/1875]
loss: 0.133956  [1500/1875]
loss: 0.144801  [1600/1875]
loss: 0.076985  [1700/1875]
loss: 0.103241  [1800/1875]
Test:
Accuracy: 92.0%, Avg loss: 0.281193

Done!