【NLP】BERT蒸馏完全指南|原理/技巧/代码
小朋友,關(guān)于模型蒸餾,你是否有很多問(wèn)號(hào):
蒸餾是什么?怎么蒸BERT?
BERT蒸餾有什么技巧?如何調(diào)參?
蒸餾代碼怎么寫(xiě)?有現(xiàn)成的嗎?
今天rumor就結(jié)合Distilled BiLSTM/BERT-PKD/DistillBERT/TinyBERT/MobileBERT/MiniLM六大經(jīng)典模型,帶大家把BERT蒸餾整到明明白白!
模型蒸餾原理
Hinton在NIPS2014[1]提出了知識(shí)蒸餾(Knowledge Distillation)的概念,旨在把一個(gè)大模型或者多個(gè)模型ensemble學(xué)到的知識(shí)遷移到另一個(gè)輕量級(jí)單模型上,方便部署。簡(jiǎn)單的說(shuō)就是用小模型去學(xué)習(xí)大模型的預(yù)測(cè)結(jié)果,而不是直接學(xué)習(xí)訓(xùn)練集中的label。
在蒸餾的過(guò)程中,我們將原始大模型稱為教師模型(teacher),新的小模型稱為學(xué)生模型(student),訓(xùn)練集中的標(biāo)簽稱為hard label,教師模型預(yù)測(cè)的概率輸出為soft label,temperature(T)是用來(lái)調(diào)整soft label的超參數(shù)。
蒸餾這個(gè)概念之所以work,核心思想是因?yàn)?strong>好模型的目標(biāo)不是擬合訓(xùn)練數(shù)據(jù),而是學(xué)習(xí)如何泛化到新的數(shù)據(jù)。所以蒸餾的目標(biāo)是讓學(xué)生模型學(xué)習(xí)到教師模型的泛化能力,理論上得到的結(jié)果會(huì)比單純擬合訓(xùn)練數(shù)據(jù)的學(xué)生模型要好。
如何蒸餾
蒸餾發(fā)展到今天,有各種各樣的花式方法,我們先從最基本的說(shuō)起。
之前提到學(xué)生模型需要通過(guò)教師模型的輸出學(xué)習(xí)泛化能力,那對(duì)于簡(jiǎn)單的二分類(lèi)任務(wù)來(lái)說(shuō),直接拿教師預(yù)測(cè)的0/1結(jié)果會(huì)與訓(xùn)練集差不多,沒(méi)什么意義,那拿概率值是不是好一些?于是Hinton采用了教師模型的輸出概率q,同時(shí)為了更好地控制輸出概率的平滑程度,給教師模型的softmax中加了一個(gè)參數(shù)T。
有了教師模型的輸出后,學(xué)生模型的目標(biāo)就是盡可能擬合教師模型的輸出,新loss就變成了:
其中CE是交叉熵(Cross-Entropy),y是真實(shí)label,p是學(xué)生模型的預(yù)測(cè)結(jié)果,是蒸餾loss的權(quán)重。這里要注意的是,因?yàn)閷W(xué)生模型要擬合教師模型的分布,所以在求p時(shí)的也要使用一樣的參數(shù)T。另外,因?yàn)樵谇筇荻葧r(shí)新的目標(biāo)函數(shù)會(huì)導(dǎo)致梯度是以前的 ,所以要再乘上,不然T變了的話hard label不減小(T=1),但soft label會(huì)變。
有同學(xué)可能會(huì)疑惑:如果可以擬合prob,那直接擬合logits可以嗎?
當(dāng)然可以,Hinton在論文中進(jìn)行了證明,如果T很大,且logits分布的均值為0時(shí),優(yōu)化概率交叉熵和logits的平方差是等價(jià)的。
BERT蒸餾
在BERT提出后,如何瘦身就成了一個(gè)重要分支。主流的方法主要有剪枝、蒸餾和量化。量化的提升有限,因此免不了采用剪枝+蒸餾的融合方法來(lái)獲取更好的效果。接下來(lái)將介紹BERT蒸餾的主要發(fā)展脈絡(luò),從各個(gè)研究看來(lái),蒸餾的提升一方面來(lái)源于從精調(diào)階段蒸餾->預(yù)訓(xùn)練階段蒸餾,另一方面則來(lái)源于蒸餾最后一層知識(shí)->蒸餾隱層知識(shí)->蒸餾注意力矩陣。
Distilled BiLSTM
Distilled BiLSTM[2]于2019年5月提出,作者將BERT-large蒸餾到了單層的BiLSTM中,參數(shù)量減少了100倍,速度提升了15倍,效果雖然比BERT差不少,但可以和ELMo打成平手。
Distilled BiLSTM的教師模型采用精調(diào)過(guò)的BERT-large,學(xué)生模型采用BiLSTM+ReLU,蒸餾的目標(biāo)是hard labe的交叉熵+logits之間的MSE(作者經(jīng)過(guò)實(shí)驗(yàn)發(fā)現(xiàn)MSE比上文的更好)。
同時(shí)因?yàn)槿蝿?wù)數(shù)據(jù)有限,作者基于以下規(guī)則進(jìn)行了10+倍的數(shù)據(jù)擴(kuò)充:
用[MASK]隨機(jī)替換單詞
基于POS標(biāo)簽替換單詞
從樣本中隨機(jī)取出n-gram作為新的樣本
但由于沒(méi)有消融實(shí)驗(yàn),無(wú)法知道數(shù)據(jù)增強(qiáng)給模型提升了多少最終效果。
BERT-PKD (EMNLP2019)
既然BERT有那么多層,是不是可以蒸餾中間層的知識(shí),讓學(xué)生模型更好地?cái)M合呢?
BERT-PKD[3]不同于之前的研究,提出了Patient Knowledge Distillation,即從教師模型的中間層提取知識(shí),避免在蒸餾最后一層時(shí)擬合過(guò)快的現(xiàn)象(有過(guò)擬合的風(fēng)險(xiǎn))。
對(duì)于中間層的蒸餾,作者采用了歸一化之后MSE,稱為PT loss。
教師模型采用精調(diào)好的BERT-base,學(xué)生模型一個(gè)6層一個(gè)3層。為了初始化一個(gè)更好的學(xué)生模型,作者提出了兩種策略,一種是PKD-skip,即用BERT-base的第[2,4,6,8,10]層,另一種是PKD-last,采用第[7,8,9,10,11]層。最終實(shí)驗(yàn)顯示PKD-skip要略好一點(diǎn)點(diǎn)(<0.01)。
DistillBERT (NIPS2019)
之前的工作都是對(duì)精調(diào)后的BERT進(jìn)行蒸餾,學(xué)生模型學(xué)到的都是任務(wù)相關(guān)的知識(shí)。HuggingFace則提出了DistillBERT[4],在預(yù)訓(xùn)練階段進(jìn)行蒸餾。將尺寸減小了40%,速度提升60%,效果好于BERT-PKD,為教師模型的97%。
DistillBERT的教師模型采用了預(yù)訓(xùn)練好的BERT-base,學(xué)生模型則是6層transformer,采用了PKD-skip的方式進(jìn)行初始化。和之前蒸餾目標(biāo)不同的是,為了調(diào)整教師和學(xué)生的隱層向量方向,作者新增了一個(gè)cosine embedding loss,蒸餾最后一層hidden的。最終損失函數(shù)由MLM loss、教師-學(xué)生最后一層的交叉熵、隱層之間的cosine loss組成。從消融實(shí)驗(yàn)可以看出,MLM loss對(duì)于學(xué)生模型的表現(xiàn)影響較小,同時(shí)初始化也是影響效果的重要因素:
TinyBERT(EMNLP2019)
既然精調(diào)階段、預(yù)訓(xùn)練階段都分別被蒸餾過(guò)了,理論上兩步聯(lián)合起來(lái)的效果可能會(huì)更好。
TinyBERT[5]就提出了two-stage learning框架,分別在預(yù)訓(xùn)練和精調(diào)階段蒸餾教師模型,得到了參數(shù)量減少7.5倍,速度提升9.4倍的4層BERT,效果可以達(dá)到教師模型的96.8%,同時(shí)這種方法訓(xùn)出的6層模型甚至接近BERT-base,超過(guò)了BERT-PKD和DistillBERT。
TinyBERT的教師模型采用BERT-base。作者參考其他研究的結(jié)論,即注意力矩陣可以捕獲到豐富的知識(shí),提出了注意力矩陣的蒸餾,采用教師-學(xué)生注意力矩陣logits的MSE作為損失函數(shù)(這里不取attention prob是實(shí)驗(yàn)表明前者收斂更快)。另外,作者還對(duì)embedding進(jìn)行了蒸餾,同樣是采用MSE作為損失。
于是整體的loss計(jì)算可以用下式表示:
其中m表示層數(shù)。表示教師-學(xué)生最后一層logits的交叉熵。
最后的實(shí)驗(yàn)中,預(yù)訓(xùn)練階段只對(duì)中間層進(jìn)行了蒸餾;精調(diào)階段則先對(duì)中間層蒸餾20個(gè)epochs,再對(duì)最后一層蒸餾3個(gè)epochs。
上圖是各個(gè)階段的消融實(shí)驗(yàn)。GD(General Distillation)表示預(yù)訓(xùn)練蒸餾,TD(Task Distillation)表示精調(diào)階段蒸餾,DA(Data Augmentation)表示數(shù)據(jù)增強(qiáng),主要用于精調(diào)階段。從消融實(shí)驗(yàn)來(lái)看GD帶來(lái)的提升不如TD或者DA,TD和DA對(duì)最終結(jié)果的影響差不多(有種蒸了這么半天還不如多標(biāo)點(diǎn)數(shù)據(jù)的感覺(jué)=.=)。
MobileBERT(ACL2020)
前文介紹的模型都是層次剪枝+蒸餾的操作,MobileBERT[6]則致力于減少每層的維度,在保留24層的情況下,減少了4.3倍的參數(shù),速度提升5.5倍,在GLUE上平均只比BERT-base低了0.6個(gè)點(diǎn),效果好于TinyBERT和DistillBERT。
MobileBERT壓縮維度的主要思想在于bottleneck機(jī)制,如下圖所示:
其中a是標(biāo)準(zhǔn)的BERT,b是加入bottleneck的BERT-large,作為教師模型,c是加入bottleneck的學(xué)生模型。Bottleneck的原理是在transformer的輸入輸出各加入一個(gè)線性層,實(shí)現(xiàn)維度的縮放。對(duì)于教師模型,embedding的維度是512,進(jìn)入transformer后擴(kuò)大為1024,而學(xué)生模型則是從512縮小至128,使得參數(shù)量驟減。
另外,作者發(fā)現(xiàn)在標(biāo)準(zhǔn)BERT中,多頭注意力機(jī)制MHA和非線性層FFN的參數(shù)比為1:2,這個(gè)參數(shù)比相比其他比例更好。所以為了維持比例,會(huì)在學(xué)生模型中多加幾層FFN。
MobileBERT的蒸餾中,作者先用b的結(jié)構(gòu)預(yù)訓(xùn)練一個(gè)BERT-large,再蒸餾到24層學(xué)生模型中。蒸餾的loss有多個(gè):
Feature Map Transfer:隱層的MSE
Attention Transfer:注意力矩陣的KL散度
Pre-training Distillation:
同時(shí)作者還研究了三種不同的蒸餾策略:直接蒸餾所有層、先蒸餾中間層再蒸餾最后一層、逐層蒸餾。如下圖:
最后的結(jié)論是逐層蒸餾效果最好,但差距最大才0.5個(gè)點(diǎn),性價(jià)比有些低了。。
MobileBERT還有一點(diǎn)不同于之前的TinyBERT,就是預(yù)訓(xùn)練階段蒸餾之后,作者直接在MobileBERT上用任務(wù)數(shù)據(jù)精調(diào),而不需要再進(jìn)行精調(diào)階段的蒸餾,方便了很多。
MiniLM
之前的各種模型基本上把BERT里面能蒸餾的都蒸了個(gè)遍,但MiniLM[7]還是找到了新的藍(lán)海——蒸餾Value-Value矩陣:
Value-Relation Transfer可以讓學(xué)生模型更深入地模仿教師模型,實(shí)驗(yàn)表明可以帶來(lái)1-2個(gè)點(diǎn)的提升。同時(shí)作者考慮到學(xué)生模型的層數(shù)、維度都可能和教師模型不同,在實(shí)驗(yàn)中只蒸餾最后一層,并且只蒸餾這兩個(gè)矩陣的KL散度,簡(jiǎn)直是懶癌福音。
另外,作者還引入了助教機(jī)制。當(dāng)學(xué)生模型的層數(shù)、維度都小很多時(shí),先用一個(gè)維度小但層數(shù)和教師模型一致的助教模型蒸餾,之后再把助教的知識(shí)傳遞給學(xué)生。
最終采用BERT-base作為教師,實(shí)驗(yàn)下來(lái)6層的學(xué)生模型比起TinyBERT和DistillBERT好了不少,基本是20年性價(jià)比數(shù)一數(shù)二的蒸餾了。
BERT蒸餾技巧
介紹了BERT蒸餾的幾個(gè)經(jīng)典模型之后,真正要上手前還是要把幾個(gè)問(wèn)題都考慮清楚,下面就來(lái)討論一些蒸餾中的變量。
剪層還是減維度?
這個(gè)選擇取決于是預(yù)訓(xùn)練蒸餾還是精調(diào)蒸餾。預(yù)訓(xùn)練蒸餾的數(shù)據(jù)比較充分,可以參考MiniLM、MobileBERT或者TinyBERT那樣進(jìn)行剪層+維度縮減,如果想蒸餾中間層,又不想像MobileBERT一樣增加bottleneck機(jī)制重新訓(xùn)練一個(gè)教師模型的話可以參考TinyBERT,在計(jì)算隱層loss時(shí)增加一個(gè)線性變換,擴(kuò)大學(xué)生模型的維度:
對(duì)于針對(duì)某項(xiàng)任務(wù)、只想蒸餾精調(diào)后BERT的情況,則推薦進(jìn)行剪層,同時(shí)利用教師模型的層對(duì)學(xué)生模型進(jìn)行初始化。從BERT-PKD以及DistillBERT的結(jié)論來(lái)看,采用skip(每隔n層選一層)的初始化策略會(huì)優(yōu)于只選前k層或后k層。
用哪個(gè)Loss?
看完原理后相信大家也發(fā)現(xiàn)了,基本上每個(gè)模型蒸餾都用的是不同的損失函數(shù),CE、KL、MSE、Cos魔幻組合,自己蒸餾時(shí)都不知道選哪個(gè)好。。于是rumor我強(qiáng)行梳理了一番,大家可以根據(jù)自己的任務(wù)目標(biāo)挑選:
對(duì)于hard label,使用KL和CE是一樣的,因?yàn)?#xff0c;訓(xùn)練集不變時(shí)label分布是一定的。但對(duì)于soft label則不同了,不過(guò)表中不少模型還是采用了CE,只有Distilled BiLSTM發(fā)現(xiàn)更好。個(gè)人認(rèn)為可以CE/MSE/KL都試一下,但MSE有個(gè)好處是可以避免T的調(diào)參。
中間層輸出的蒸餾,大多數(shù)模型都采用了MSE,只有DistillBERT加入了cosine loss來(lái)對(duì)齊方向。
注意力矩陣的蒸餾loss則比較統(tǒng)一,如果要蒸餾softmax之前的attention logits可以采用MSE,之后的attention prob可以用KL散度。
T和如何設(shè)置?
超參數(shù)主要控制soft label和hard label的loss比例,Distilled BiLSTM在實(shí)驗(yàn)中發(fā)現(xiàn)只使用soft label會(huì)得到最好的效果。個(gè)人建議讓soft label占比更多一些,一方面是強(qiáng)迫學(xué)生更多的教師知識(shí),另一方面實(shí)驗(yàn)證實(shí)soft target可以起到正則化的作用,讓學(xué)生模型更穩(wěn)定地收斂。
超參數(shù)T主要控制預(yù)測(cè)分布的平滑程度,TinyBERT實(shí)驗(yàn)發(fā)現(xiàn)T=1更好,BERT-PKD的搜索空間則是{5, 10, 20}。因此建議在1~20之間多嘗試幾次,T越大越能學(xué)到teacher模型的泛化信息。比如MNIST在對(duì)2的手寫(xiě)圖片分類(lèi)時(shí),可能給2分配0.9的置信度,3是1e-6,7是1e-9,從這個(gè)分布可以看出2和3有一定的相似度,這種時(shí)候可以調(diào)大T,讓概率分布更平滑,展示teacher更多的泛化能力。
需要逐層蒸餾嗎?
如果不是特別追求零點(diǎn)幾個(gè)點(diǎn)的提升,建議無(wú)腦一次性蒸餾,從MobileBERT來(lái)看這個(gè)操作性價(jià)比太低了。
蒸餾代碼實(shí)戰(zhàn)
目前Pytorch版本的模型蒸餾有一個(gè)非常贊的開(kāi)源工具TextBrewer[8],在它的src/textbrewer/losses.py文件下可以看到各種loss的實(shí)現(xiàn)。
最后輸出層的CE/KL/MSE loss比較簡(jiǎn)單,只需要將兩者的logits除temperature之后正常計(jì)算就可以了,以CE為例:
def?kd_ce_loss(logits_S,?logits_T,?temperature=1):'''Calculate?the?cross?entropy?between?logits_S?and?logits_T:param?logits_S:?Tensor?of?shape?(batch_size,?length,?num_labels)?or?(batch_size,?num_labels):param?logits_T:?Tensor?of?shape?(batch_size,?length,?num_labels)?or?(batch_size,?num_labels):param?temperature:?A?float?or?a?tensor?of?shape?(batch_size,?length)?or?(batch_size,)'''if?isinstance(temperature,?torch.Tensor)?and?temperature.dim()?>?0:temperature?=?temperature.unsqueeze(-1)beta_logits_T?=?logits_T?/?temperaturebeta_logits_S?=?logits_S?/?temperaturep_T?=?F.softmax(beta_logits_T,?dim=-1)loss?=?-(p_T?*?F.log_softmax(beta_logits_S,?dim=-1)).sum(dim=-1).mean()return?loss對(duì)于hidden MSE的蒸餾loss,則需要去除被mask的部分,另外如果維度不一致,需要額外加一個(gè)線性變換,TextBrewer默認(rèn)輸入維度是一致的:
def?hid_mse_loss(state_S,?state_T,?mask=None):'''*?Calculates?the?mse?loss?between?`state_S`?and?`state_T`,?which?are?the?hidden?state?of?the?models.*?If?the?`inputs_mask`?is?given,?masks?the?positions?where?``input_mask==0``.*?If?the?hidden?sizes?of?student?and?teacher?are?different,?'proj'?option?is?required?in?`inetermediate_matches`?to?match?the?dimensions.:param?torch.Tensor?state_S:?tensor?of?shape??(*batch_size*,?*length*,?*hidden_size*):param?torch.Tensor?state_T:?tensor?of?shape??(*batch_size*,?*length*,?*hidden_size*):param?torch.Tensor?mask:????tensor?of?shape??(*batch_size*,?*length*)'''if?mask?is?None:loss?=?F.mse_loss(state_S,?state_T)else:mask?=?mask.to(state_S)valid_count?=?mask.sum()?*?state_S.size(-1)loss?=?(F.mse_loss(state_S,?state_T,?reduction='none')?*?mask.unsqueeze(-1)).sum()?/?valid_countreturn?loss蒸餾attention矩陣則也要考慮mask,但注意這里要處理的維度是N*N:
def?att_mse_loss(attention_S,?attention_T,?mask=None):'''*?Calculates?the?mse?loss?between?`attention_S`?and?`attention_T`.*?If?the?`inputs_mask`?is?given,?masks?the?positions?where?``input_mask==0``.:param?torch.Tensor?logits_S:?tensor?of?shape??(*batch_size*,?*num_heads*,?*length*,?*length*):param?torch.Tensor?logits_T:?tensor?of?shape??(*batch_size*,?*num_heads*,?*length*,?*length*):param?torch.Tensor?mask:?tensor?of?shape??(*batch_size*,?*length*)'''if?mask?is?None:attention_S_select?=?torch.where(attention_S?<=?-1e-3,?torch.zeros_like(attention_S),?attention_S)attention_T_select?=?torch.where(attention_T?<=?-1e-3,?torch.zeros_like(attention_T),?attention_T)loss?=?F.mse_loss(attention_S_select,?attention_T_select)else:mask?=?mask.to(attention_S).unsqueeze(1).expand(-1,?attention_S.size(1),?-1)?#?(bs,?num_of_heads,?len)valid_count?=?torch.pow(mask.sum(dim=2),2).sum()loss?=?(F.mse_loss(attention_S,?attention_T,?reduction='none')?*?mask.unsqueeze(-1)?*?mask.unsqueeze(2)).sum()?/?valid_countreturn?loss最后是只在DistillBERT中出現(xiàn)的cosine loss,可以直接使用pytorch的默認(rèn)接口:
def?cos_loss(state_S,?state_T,?mask=None):'''*?Computes?the?cosine?similarity?loss?between?the?inputs.?This?is?the?loss?used?in?DistilBERT,?see?`DistilBERT?<https://arxiv.org/abs/1910.01108>`_*?If?the?`inputs_mask`?is?given,?masks?the?positions?where?``input_mask==0``.*?If?the?hidden?sizes?of?student?and?teacher?are?different,?'proj'?option?is?required?in?`inetermediate_matches`?to?match?the?dimensions.:param?torch.Tensor?state_S:?tensor?of?shape??(*batch_size*,?*length*,?*hidden_size*):param?torch.Tensor?state_T:?tensor?of?shape??(*batch_size*,?*length*,?*hidden_size*):param?torch.Tensor?mask:????tensor?of?shape??(*batch_size*,?*length*)'''if?mask?is??None:state_S?=?state_S.view(-1,state_S.size(-1))state_T?=?state_T.view(-1,state_T.size(-1))else:mask?=?mask.to(state_S).unsqueeze(-1).expand_as(state_S).to(mask_dtype)?#(bs,len,dim)state_S?=?torch.masked_select(state_S,?mask).view(-1,?mask.size(-1))??#(bs?*?select,?dim)state_T?=?torch.masked_select(state_T,?mask).view(-1,?mask.size(-1))??#?(bs?*?select,?dim)target?=?state_S.new(state_S.size(0)).fill_(1)loss?=?F.cosine_embedding_loss(state_S,?state_T,?target,?reduction='mean')return?loss關(guān)于更多的蒸餾實(shí)戰(zhàn)經(jīng)驗(yàn),可以參考知乎@邱震宇同學(xué)的模型蒸餾技巧小結(jié)[9]。
總結(jié)
短暫的學(xué)習(xí)就要結(jié)束了,蒸餾雖然費(fèi)勁,但確實(shí)是目前小模型提升效果的主要方法之一,在很多研究中都有用到。另外,模型蒸餾有一個(gè)好處是可以利用大批量的無(wú)監(jiān)督數(shù)據(jù),只要能找到任務(wù)相關(guān)的,就可以蒸餾提升模型的泛化能力。標(biāo)注數(shù)據(jù)少的同學(xué)還等什么?快去試試叭!
往期精彩回顧適合初學(xué)者入門(mén)人工智能的路線及資料下載機(jī)器學(xué)習(xí)及深度學(xué)習(xí)筆記等資料打印機(jī)器學(xué)習(xí)在線手冊(cè)深度學(xué)習(xí)筆記專(zhuān)輯《統(tǒng)計(jì)學(xué)習(xí)方法》的代碼復(fù)現(xiàn)專(zhuān)輯 AI基礎(chǔ)下載機(jī)器學(xué)習(xí)的數(shù)學(xué)基礎(chǔ)專(zhuān)輯 獲取一折本站知識(shí)星球優(yōu)惠券,復(fù)制鏈接直接打開(kāi): https://t.zsxq.com/y7uvZF6 本站qq群704220115。加入微信群請(qǐng)掃碼:總結(jié)
以上是生活随笔為你收集整理的【NLP】BERT蒸馏完全指南|原理/技巧/代码的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 【Python基础】Python 打基础
- 下一篇: 【收藏】机器学习入门的常见问题集(文末送