NLP中知识蒸馏
NLP中的知識蒸餾
一、什么是知識蒸餾
知識蒸餾一個重要目的是讓學生模型學習到老師模型的泛化能力,讓輕量級的學生模型也可以具備重量級老師模型的幾乎同樣的能力。
一個很高效的蒸餾方法就是使用老師網絡softmax層輸出的類別概率來作為軟標簽,和學生網絡的softmax輸出做交叉熵。
傳統訓練方法是硬標簽,正類是1,其他所有負類都是0。但知識蒸餾的訓練過程過程是用老師模型的類別概率作為軟標簽。
二、為什么需要知識蒸餾
大模型雖然效果很好,但模型較重推理速度太慢無法瞞足工業要求,而小模型輕,推理速度快,但是直接使用數據訓練效果較差,知識蒸餾就是想讓小模型在擁有較快的推理速度下,也具備大模型的泛化能力。
三、知識蒸餾中的SoftMax
原始的softmax:
qi=exp(zi)∑jexp(zj)q_i = {\frac{exp(z_i)}{{\sum_{j}{exp(z_j)}}}}qi?=∑j?exp(zj?)exp(zi?)?
上述有說到,知識蒸餾是student模型學習tearch模型的軟標簽,但是如果
所以對softmax加了溫度:
qi=exp(zi/T)∑jexp(zj/T)q_i = {\frac{exp(z_i/T)}{{\sum_{j}{exp(z_j/T)}}}}qi?=∑j?exp(zj?/T)exp(zi?/T)?
根據公式可以看出,讓T越大時,softmax輸出值越平滑,輸出值得熵越大,會放大負標簽攜帶的信息,模型會相對校對的關注負標簽,能夠充分的學習。 一般來講T會大于1;
四、如何選擇溫度:
說白了溫度的高低改變的是學生網絡對負標簽的關注程度
- 溫度較低時,負類別攜帶的信息會被相對減少,對負類別的關注較少,負類別的概率越低,關注越少。
- 溫度較高時,負類別的概率值會相對增大,負類別攜帶的信息會被相對地放大,學生網絡會更多關注到負標簽。
實際上,負類別中包含一定的信息,尤其是那些概率值較高的負類別。 但由于老師網絡的負類別可能會有噪聲,并且負類別的概率值越低,其信息就越不可靠。因此溫度的選取比較看經驗,本質上就是在下面兩件事之中取舍
總的來說,溫度的選擇和學生網絡的大小有關,學生網絡參數量比較小的時候,相對比較低的溫度就可以了,因為參數量小的模型不能捕獲所有知識,所以可以適當忽略掉一些負標簽的信息。
五、如何蒸餾、LOSS是什么樣的
第一步是訓練老師網絡;第二步是蒸餾老師網絡的知識到學生網絡。
高溫蒸餾過程的目標函數由distill loss(對應軟標簽)和student loss(對應硬標簽)加權得到。
L=αLsoft+βLhardL = \alpha L_{soft} + \beta L_{hard} L=αLsoft?+βLhard?
distill loss(對應軟標簽) :
是老師模型softmax經過高溫后輸出的概率分布和學生網絡在同等溫度下的概率分布做交叉熵, 軟標簽 loss:
Lsoft=?∑j=1npjTlog(qjT),其中piT=exp(vi/T)∑k=1nexp(vk/T),qiT=exp(vi/T)∑k=1nexp(vk/T)L_{soft} = -{\sum_{j=1}^n {p_j^T log(q_j^T)}},其中 p_i^T = {\frac{exp(v_i/T)}{{\sum_{k=1}^n exp(v_k/T)}}}, q_i^T = {\frac{exp(v_i/T)}{\sum_{k=1}^n {exp(v_k/T)}}}Lsoft?=?j=1∑n?pjT?log(qjT?),其中piT?=∑k=1n?exp(vk?/T)exp(vi?/T)?,qiT?=∑k=1n?exp(vk?/T)exp(vi?/T)?
student loss(對應硬標簽) :
是學生網絡在溫度為1下的概率分布和真實標簽做交叉熵,硬標簽 loss:
Lhard=?∑jncj1log(qj1),其中qj1=exp(vj)∑knexp(vk)L_{hard} = -{\sum_j^n {c_j^1 log(q_j^1)}}, 其中 q_j^1 = {\frac{exp(v_j)}{\sum_{k}^n {exp(v_k)}}}Lhard?=?j∑n?cj1?log(qj1?),其中qj1?=∑kn?exp(vk?)exp(vj?)?
六、項目開展和算法調優過程
損失函數的比較和選擇
- 交叉熵損失(CrossEntropyLoss):基于softmax-T計算損失。其中softmax-T上述有過介紹,不在過多贅述。
- 均方差損失(MESLoss):基于logits直接計算。
在我的實驗中,兩者之間的訓練結果并無太大差異,反而MSELoss計算方法獲得的結果更優。 (其實是近MSELoss,但大多數實驗者直接用MSELoss替代)基本類似。
使用MSELoss的另一個好處是,避免了超參數T的使用。 超參數T的使用還會影響soft-loss和hard-loss的比重,雖然理論上需要給soft-loss乘以 T2T^2T2 ,讓彼此的權重在同一個數量級上。
對于知識蒸餾建議使用MSELoss,而非使用原本的softmax-T-loss(Hinton,2014),能達到更好的效果,理論和實驗都有證明。 - 項目中的具體做法:
- 這里我使用了一種soft-label的方法。是將teacher模型的logits表示經過softmax后,與one-hot表征的實際label進行相加, 注意這里引入相加的權重alpha,實驗做好的值為0.5,alpha越大越依賴教師模型的logits。然后得到一個新的label表示。如teacher-logit-softmax = [0.2,0.7,0.1],實際標簽one-hot = [0,1,0],alpha = 0.5,那么最后的label = [0.1,0.85,0.05],然后用這個label和student計算獲得的logits進行MSELoss計算,求導。這種方法獲得的結果和直接用MSELoss計算后,然后使用alpha權重相加結果類似,但好處是少了一次MSELoss的計算過程,在訓練時,訓練速度更快。
Teacher模型
因為是中文NLP任務,對于teacher模型選擇的標準是,盡量好,盡量優秀,甚至可以使用集成學習的方法獲得最優結果。
項目中使用了中文的Roberta-base模型作為teacher模型(已經對下游NER實體進行了Finetune,精度F1 = 94.67%),具體參數:Epoch=3,max_sequence_length = 256,batch_size = 32,model_size = 42.2M。
Student模型
前后使用多種Student模型,選取的條件是速度滿足當前模型預測的速度要求,(不做蒸餾前,純訓練)精度越高越好。
將BERT模型蒸餾至TextCNN 和BiLSTM等小模型上,精度下降3%,速度提升400倍。注意文章使用了word2vec詞向量,并非完全從頭訓練,具體細節可看論文和代碼。
蒸餾學習的Student模型分別使用了ALBert-Base,ALBert-Tiny和ELECTRA-Small這三種模型,模型使用的alpha = 0.5,使用的是MSELoss的方法,具體結果:
| Roberta-Base | 0.94675 | — | 412M | 95.35s |
| AlBert-base | 0.810 | 0.909 | 42.2M | 27.3 |
| ALBert-Tiny | 0.612 | 0.824 | 16.3M | 9.3s |
| ELECTRA-small | 0.9124 | 0.9267 | 49.4M | 24.8s |
蒸餾中的數據增強
使用訓練好的teacher模型對數據打標,形成偽標簽,再訓練student模型,即使部分case teacher沒有標對,也沒有很大的關系,目的就是讓student更像teacher,本來badcase就很小,對訓練影響度有限,但是偽標簽數據不易過度,以免真正影響效果。
Batch-size和max-sequence-length的使用:
多步蒸餾到超小模型
如上述實驗中的獎RoBerta-Base模型內容蒸餾到ALBert-Tiny,模型的size差異大約在30倍,如果直接蒸餾,效果會不好。精度大約只能達到82.4%。這里可以借鑒miniLM(Ref-9)的一種操作Trick,間接蒸餾。具體做法是先將大模型(如:RoBerta-Base,94.7%)里的知識蒸餾到一個中(過渡)模型(如:ELECTRA-base,92.1%),然后再用中模型作為teacher,將知識蒸餾到真正的小模型(如此處的ALBert-Tiny),模型精度最終可以達到88.3%,精度大約有6個點的提升。
七、知識蒸餾需要注意的點
-
溫度T,高溫T。通常模型訓練的時候使用高溫T,而在模型測試和預測階段的時候,是不使用teacher模型的,僅使用student模型進行測試和預測,也就是T在預測階段不使用。
-
MSELoss計算L-soft。不使用上述復雜的L-soft,而使用簡單的均方差損失函數——MSELoss。
-
Hard-Loss加入模型。即便使用了Soflt-Loss,還是需要引入Hard-Loss以及超參數 目的是teacher模型也可能存在無法完全學對的可能,所以在數據質量有保證的情況下,引入學生模型的hard-loss能更好的學會teacher無法學會的知識。實際使用過程中也發現,引入hard-loss很有效果。
-
使用更多損失函數。,Hinton的蒸餾學習使用的是Cross-Entorpy作為損失函數,其實損失函數不止于交叉熵損失函數,包括MSELoss,NLLLoss,HingeLoss等。實驗中,我使用了MSELoss和CELoss做比較,發現二者對于Student模型的效果類似,所以對于不同下游任務,可以使用更貼合的Loss函數,不必局限于CELoss。但是對于蒸餾學習的理解一定要到位,才能更合理的利用Loss。
-
集成學習加入到蒸餾學習中。通常我們不會僅僅使用一個老師,而是使用多個teacher,然后將多個teacher的知識權重相加引入到student模型中。這是將集成學習和蒸餾學習相融合,能讓學生學到更多信息,但是也同時增加了模型訓練的難度(增多了超參數的數量以及集成學習方法的比較),對于初學者不建議使用。
八、知識蒸餾的幾個思考
student loss的必要性
因為老師網絡也有一定的錯誤率,使用ground truth可以有效降低錯誤被傳播給學生網絡的可能。
舉例,老師雖然學識遠遠超過學生,但是他仍然有出錯的可能,而這時候如果學生在老師的教授之外,可以同時參考到標準答案,就可以有效地降低被老師偶爾的錯誤“帶偏”的可能性。
為什么student loss(硬標簽)所占比重比較小的時候,能產生最好的結果
這是一個經驗的結論。一個可能的原因是,由于soft target產生的gradient與hard target產生的gradient之間有與 T 相關的比值。
在同時使用soft target和hard target的時候,需要在soft target之前乘上 T2T^2T2 這個系數,這樣才能保證soft target和hard target貢獻的梯度量基本一致。
能不能直接match logits(不經過softmax)
直接match logits指的是,直接使用softmax層的輸入logits(而不是輸出)作為soft targets,需要最小化的目標函數是Net-T和Net-S的logits之間的平方差。直接上結論: 直接match logits的做法是 T→∞ 的情況下的特殊情形。
總結
- 上一篇: 国科大学习资料--最优化计算方法(王晓)
- 下一篇: 大数据实时处理学期总结