让BERT瘦下来 MindSpore量化训练极低比特语言模型 TernaryBERT
让BERT瘦下来 MindSpore量化训练极低比特语言模型 TernaryBERT


基于Transformer的预训练模型如BERT在许多自然语言处理任务中都取得了显著的性能。然而,这些模型昂贵的计算和内存都阻碍了它们在资源受限设备上的部署。因此,我们提出了TernaryBERT,它将微调的BERT模型中权值三值化。此外,为了减少低比特导致的精度下降,我们在训练过程中采用了知识蒸馏技术。在GLUE和SQuAD上进行的实验表明,我们提出的TernaryBERT量化方法优于其他的BERT量化方法,甚至可以达到与全精度模型相当的性能,同时将模型缩小了14.9倍。现在TernaryBERT的开源代码已经在MindSpore上首发了。

图1:不同算法的模型尺寸与MNLI-m精度对比。
我们提出的方法(红色方块)优于其他的BERT压缩方法。
图片来源:https://arxiv.org/abs/2009.12812
论文链接:
https://arxiv.org/abs/2009.12812
开源地址:
https://gitee.com/mindspore/mindspore/tree/master

BERT模型由Transformer层构成。标准的Transformer层包括两个主要的子层:多头注意力(MHA)模块和前馈网络(FFN)。
对于第 个Transformer层,假设输入为 ,其中 和 分别是序列长度和隐藏状态大小。假设每层都有 个attention头,头部 由 参数化,其中 。通过query和key的点积计算attention score。
将softmax函数应用于归一化的分数以得到
。 ,
其中 可以是 。多头注意力的输出是:
FFN层由两个线性层组成,分别由 和 参数化,其中 是FFN的intermediate层的神经元数目。将FFN的输入表示为 ,然后将输出计算为:
结合上面两式,第l个Transformer层的前向传播可以写成:
其中 是层归一化。第一个Transformer层的输入是token embedding、segment embedding和position embedding的结合。这里 是输入序列, 、、分别是可学习的word embedding、segment embedding和position embedding。
对于权重量化,我们量化来自所有Transformer层中的权重
、、、、、 ,以及word embedding中的 。除了这些权重外,我们还量化了前向传播中所有线性层的输入和矩阵乘法算子。我们不量化 、 和线性层中的bias,因为它们所涉及的参数可以忽略不计。我们也不量化softmax算子、层归一化和最后一个任务特定层,因为这些算子中包含的参数可以忽略不计,并且量化它们会导致显著的精度下降。

下面我们将讨论下图中权重三值化函数 的选择。
权重三值化在ternary-connect(Z. Lin, M. Courbariaux, R. Memisevic, and Y. Bengio. 2016. Neural networks with few multiplications.)中首创,其中三值化可以取通过2-bit表示的 。通过三值化,将前向过程中的大部分浮点乘法转换为浮点加法,大大减少了计算量和内存。通过添加缩放参数可以获得更好的结果。因此,为了将BERT的权重三值化,我们使用了基于近似的三值化方法TWN(F. Li, B. Zhang, and B. Liu. 2016. Ternary weight networks.),其中三元权重 可由缩放参数 和三元向量 的乘积表示为 。这里 是 中的元素个数。
在第t次训练迭代中,TWN通过最小化全精度权重 与三值化的权重 之间的距离来实现权重的三值化,我们将上述问题定义为如下的优化问题:
设 是一个阈值函数,如果 ,则 ,若 ,则 ,其他情况下 ,其中 是一个正数阈值。设 为元素乘法,上式的最优解满足:
的精确解需要昂贵的排序操作。因此,TWN给出了近似的阈值 。
在TWN的原始论文中,每个卷积层或全连接层都使用一个缩放参数。本文将缩放参数扩展到以下两个粒度:(i)layer-wise三值化,对每个权重矩阵中的所有元素使用一个缩放参数;(ii)raw-wise三值化,对权重矩阵中的每一行使用一个缩放参数。随着缩放参数的增加,raw-wise三值化具有更细的粒度和更小的量化误差。

为了使最昂贵的矩阵乘法运算更快,本文将激活(即所有线性层和矩阵乘法的输入)量化为8-bit。常用的8-bit量化方法有两种:对称和最小-最大8-bit量化。对称8-bit量化的量化值在0的两侧对称分布,而最小-最大8-bit量化的量化值均匀分布在由最小值和最大值确定的范围内。
我们发现BERT中的Transformer层的隐藏层的分布趋于负值。这种偏差在前面的层中更为明显。因此,我们对激活值使用最小-最大8-bit量化,因为它更好地解决了非对称分布。

图2:在SQuAD v1.1上训练的全精度BERT的第1和第6层隐藏层的分布。
图片来源:https://arxiv.org/abs/2009.12812
具体而言,对于激活值 中的一个元素 ,表示 和 ,最小-最大8-bit量化函数为
其中 是缩放参数。我们使用直通估计器反向传播量化后的激活值的梯度。

量化的BERT使用低比特数值来表示模型权重和激活值。因此,与全精度的对应模型相比,它的信息容量相对较低,性能更差。为了解决这一问题,我们结合了知识蒸馏技术来提高量化的BERT的性能。在这个知识蒸馏框架中,量化的BERT作为学生模型,学习去恢复Transformer层和预测层上的全精度的教师模型的行为。

图3:BERT模型的蒸馏感知三值化描述。
图片来源:https://arxiv.org/abs/2009.12812
具体来说,Transformer层 的蒸馏目标包括两部分。第一部分是蒸馏损失,它将全精度教师模型的embedding层和所有Transformer层的输出提取到量化学生模型中,通过均方误差损失(MSE): 。第二部分是从教师模型的attention score中提取知识的蒸馏损失,从每个Transformer层的所有头部 到学生模型的attention score ,即 。因此,Transformer层 的蒸馏公式如下:
除了Transformer层外,我们还在预测层提取知识,使学生模型的logits 通过soft cross-entropy(SCE)损失从教师模型中学习拟合 :
因此,在TernaryBERT训练过程中进行知识提炼的总体目标是:
我们使用对下游任务进行微调的全精度BERT初始化我们的量化模型,并使用数据增广方法(X. Jiao, Y. Yin, L. Shang, X. Jiang, X. Chen, L. Li, F. Wang, and Q. Liu. 2019. Tinybert: Distilling bert for natural language understanding.)来提高性能。整个过程称为蒸馏感知三值化,如算法1所示。

图4:算法1。
图片来源:https://arxiv.org/abs/2009.12812

和BERT量化算法对比的结果
表1显示了GLUE基准的开发集结果。从表1中我们发现:1)对于2-bit权重,由于模型容量的急剧减少,Q-BERT(或Q2BERT)与全精度BERT之间存在很大的差距。TernaryBERT的性能明显优于Q-BERT和Q2BERT,即使word embedding的比特数更少。同时,TerneyBERT以14.9倍更小的尺寸实现了与全精度基线相当的性能。2)当权值的位数增加到8时,所有量化模型的性能都得到了极大的改善,甚至可以与全精度基线相媲美,这表明设置8-8-8对BERT来说并不具有挑战性。我们提出的方法在MNLI和SST-2上都优于Q-BERT,在8个任务中有7个优于Q8BERT。3)TWN和LAT在所有任务上都取得了相似的结果,表明两种三值化方法都具有竞争力。

表1:GLUE基准上量化的BERT和TinyBERT的开发集结果。
我们将Transformer层权重、word embedding和激活的位数缩写为“W-E-A(#位)”。
表格来源:https://arxiv.org/abs/2009.12812
和其他BERT压缩方法对比
从表2可以看出,与量化以外的其他常用的BERT压缩方法相比,本文提出的方法可以获得相似或更好的性能,但要小得多。

表2:在MNLI-m上,TernaryBERT与其他压缩方法的比较。
表格来源:https://arxiv.org/abs/2009.12812

相关训练与推理代码,以及使用方法已经开源在:
https://gitee.com/mindspore/mindspore/tree/master
为了方便大家验证我们的结果以及创新,我们将模型的结构,以及超参数的设置汇总到了相关的代码仓的/script文件夹。src/config.py中存放了配置信息。参数设置以GPU训练脚本train.sh为例:

图5:训练脚本
如果想切换其他的glue数据集,只需要在--task_name的位置将上图中的sts-b更改即可。若想使用自己的数据集。可以参考src/dataset.py中构造数据pipeline的代码。只事先将文本数据转换成需要的输入格式然后封装为tfrecord或者mindrecord格式,就可以使用pipeline进行读取。

图6:构造数据pipeline的代码
TernaryBERT的模型结构的定义和激活伪量化操作放在/src/tinybert_model.py中。用户可以在这里手动插入激活的伪量化结点或更改网络结构。

图7:可以灵活插入激活伪量化结点
src/cell_wrapper.py封装了训练相关的类以及权重的伪量化操作。

图8:权重伪量化操作
最后,MindSpore的model_zoo中存放有TernaryBERT针对MNLI-m、QNLI和STS-B对应的训练脚本。模型均可达到论文中所述精度。

表3:在MNLI-m、QNLI和STS-B上,
通过mindspore实现的TernaryBERT的精度。
MindSpore官方资料
GitHub : https://github.com/mindspore-ai/mindspore
Gitee:https : //gitee.com/mindspore/mindspore
官方QQ群 : 871543426
扫描下方二维码加入MindSpore项目
