Paper:1910.01108.pdf
Code:transformers/examples/distillation at master · huggingface/transformers
核心思想:通过知识蒸馏(在 logits,hidden_states 上计算学生与教师的 loss)训练一个小(主要是层数)模型实现和大模型类似的效果。
What
动机和核心问题
- 成倍增长的计算成本
- 模型不断增长的算力和内存要求可能阻止被广泛使用
本文的研究表明,使用经过知识蒸馏得到的的比预训练的小得多的语言模型,可以在许多下游任务上达到类似的性能,而且可以在移动设备上运行。
模型和算法
知识蒸馏 [Bucila 等,2006; Hinton 等,2015] 是一种压缩技术,训练一个紧凑型模型(学生),以再现较大模型(教师)或模型集合的行为。
损失函数
其中,ti 和 si 分别表示老师和学生的概率估计。
使用了 softmax-temperature:
其中,T 控制输出分布的平滑程度,推理时设置为 1 变为标准的 Softmax,zi 表示类别 i 的分数。
最终的损失函数是 Lce 和 masked language modeling loss Lmlm 的线性组合,另外作者发现添加余弦嵌入损失(Lcos)有利于使学生和教师隐藏状态向量的方向一致。
学生架构
- 具有和老师一般的体系结构。
- 移除了 Token type embedding 和 pooler。
- 层数减少为 1/2:作者调查发现改变隐层维度对计算效率的影响比其他因素的变化要小(比如层数)。
从老师的两层中选择一层来初始化学生。蒸馏应用了Liu et al. [2019] 提出的 BERT 模型训练最佳实践。语料和 Bert 使用的一致。
蒸馏器代码
我们直接看最核心的代码:
1 | # 一个 batch |
无疑,step
就是其核心了,不过在此之前我们先看一下 batch 数据的处理。它 input 的 batch 是一个 tuple,里面包括了 batch_sequence 和 batch_sequence_length,比如其中的一个 batch 结果如下:
1 | # batch_size = 2 |
注意,每次是根据 batch 里的最长序列对其他序列进行 padding 的。
1 | def prepare_batch_mlm(batch): |
这里其实和 Bert 唯一的不同就是 mask 时采用了 token_probs,让选择 mask 时更加关注低频词,进而实现对 mask 的平滑取样(如果按平均分布取样的话,取到的 mask 可能大部分都是重复的高频词)。个人认为这一步还是挺重要的,很 make sense。
现在我们来看 step
:
1 | # --alpha_ce 5.0 --alpha_mlm 2.0 --alpha_cos 1.0 --alpha_clm 0.0 --mlm (true) |
简单总结一下:
- 分别计算 teacher 和 student 的 logits 和 hidden_states
- 计算 mask 的 logits,mask 可以选择只计算 masked tokens,也可以选择计算不含 padding 的 input tokens,两者最后用来计算 loss 的 logits 不相同,其中前者的 size 是
(n_tgt, vocab_size)
,后者的 size 是(sum(lenghts), vocab_size)
。计算 Lce(散度),并计算 loss:loss = alpha_ce × Lce - 直接用 student 的 logits 和 lm_labels 计算 Lmlm(交叉熵),并计算累计 loss:loss += alpha_mlm × Lmlm
- 计算 mask 的 hidden_states(最后一层),mask 选择不含 padding 的 input tokens(同上面第二种,也是输入的 attention mask),size 为
(sum(lenghts), hidden_dim)
。计算 Lcos(余弦嵌入),并计算累计 loss:loss += alpha_cos × Lcos
这样,我们就和前面的理论对应起来了,感性地理解,第一个和第三个 loss 是和教师一致的保证,第二个 loss 是自我(Bert)的保证,非常 make sense。
有两个小地方需要注意一下:
- MLM 的 labels 是那些 masked 掉的 token_ids(未 mask 的设为一个负数,论文设置为 -100)
- Cos 的 target 全为 1,size 为
(sum(lengths), )
特点和创新
- 实践证明了可以通过蒸馏成功地训练通用语言模型。
- 利用教师的知识进行初始化。
- 余弦嵌入损失的使用。
How
如何构造数据
根据官方代码文档,准备数据包括两步:
- binarize the data
- count the occurrences of each tokens in the data
开始之前首先需要一个 dump.txt
,每一行是一个序列,每个序列由几个连贯的句子组成。这个无需再述,我们从 OpenNMT 自带的翻译语料中选择部分作为实验材料。
第一步其实就是对每个序列使用 Bert(或其他)的词表进行 One-Hot 编码,每个文本序列对应一个数组,如:
1 | # '[CLS] with the recent plan of ecu 10m, we have slightly enlarged the list of the ngos. [SEP]' |
第二步是因为使用了 XLM 的 masked language modeling loss 需要平滑 mask token 的分布——更加关注出现次数少的词,所以需要统计数据中 token 的出现次数。最终的结果就是词表中每个 token 的出现次数。
如何开始训练
1 | # From 官方 code |
DistilBERT 用的是 Bert 的 Tokenizer,student 和 teacher 分别是两个对应的 MaskedLM。
如何使用结果
与 Bert 一样。
数据和实验
40% smaller, 60% faster, that retains 97% of the language understanding capabilities.
Discussion
相关工作
Task-specific distillation
- 将精调的分类模型 BERT 转移到基于 LSTM 的分类器:Raphael Tang, Yao Lu, Linqing Liu, Lili Mou, Olga Vechtomova, and Jimmy Lin. Distilling task-specific knowledge from bert into simple neural networks. ArXiv, abs/1903.12136, 2019.
- 由 BERT 初始化较小的 Transformer 模型,在 SQuAD 上精调:Debajyoti Chatterjee. Making neural machine reading comprehension faster. ArXiv, abs/1904.00796, 2019.
- 使用原始的预训练目标来训练较小的学生,然后通过蒸馏进行精调:Iulia Turc, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Well-read students learn better: The impact of student initialization on knowledge distillation. ArXiv, abs/1908.08962, 2019.
Multi-distillation
- 多任务:Ze Yang, Linjun Shou, Ming Gong, Wutao Lin, and Daxin Jiang. Model compression with multi-task knowledge distillation for web-scale question answering system. ArXiv, abs/1904.09636, 2019.
- 多语言:Henry Tsai, Jason Riesa, Melvin Johnson, Naveen Arivazhagan, Xin Li, and Amelia Archer. Small and practical bert models for sequence labeling. In EMNLP-IJCNLP, 2019.
Other compression techniques
- 剪枝:Paul Michel, Omer Levy, and Graham Neubig. Are sixteen heads really better than one? In NeurIPS, 2019.
- 量化:Suyog Gupta, Ankur Agrawal, Kailash Gopalakrishnan, and Pritish Narayanan. Deep learning with limited numerical precision. In ICML, 2015.