论文阅读|DeiT
Training data-efficient image transformers & distillation through attention
論文鏈接:https://export.arxiv.org/pdf/2012.12877
代碼:https://github.com/facebookresearch/deit
摘要
純基于注意力的神經網絡被證明可以解決圖像分類等圖像理解任務,但是這些高性能的網絡結構通常需要使用大型的基礎設施預先訓練了數億個圖像,因此限制了他們的采用。
為此,對于這種設計龐大的預訓練量,作者提出了一種convolution-free transformers的結構,只在Imagenet上進行訓練,就具有競爭力。在不需要其他額外數據進行預訓練的情況下,在ImageNet上達到了top-1 accuracy 達到83.1%的效果。
此外,作者還引入了teacher-student策略,依賴于一個蒸餾標記(distillation token),確保學生通過注意力從老師那里學習。當卷積網絡作為teacher時,效果達到了85.2%的準確率
DeiT能取得更好效果的方法:
- better hyperparameter更好的超參數設置(模型初始化,learning-rate等設置),
- data augmentation(多種數據增強方式)
- distillation(知識蒸餾)
Introduction
ViT的問題:
- ViT需要大量的GPU資源
- ViT的預訓練數據集JFT-300M并沒有公開
- 超參數設置不好很容易導致訓練效果差
- 只用ImageNet訓練準確率沒有很好
對于VIT訓練數據巨大,超參數難設置導致訓練效果不好的問題,提出了DeiT。
DeiT : Data-efficient image Transformers
DeiT的模型和VIT的模型幾乎是相同的,可以理解為本質上是在訓一個VIT。
DeiT特點:
- DeiT不包含卷積層,可以在沒有外部數據的情況下實現與ImageNet上的最新技術相媲美的結果。
- 引入了一種基于distillation token的新蒸餾過程,它扮演著與class token相同的角色,只不過它的目的是學習教師網絡中的預測結果,兩個標記都通過注意力在轉換器中交互。
- 通過蒸餾,圖像transformers從一個convnet學到的比從另一個性能相當的transformers學到的更多。
- 在Imagenet上預先學習的模型在轉移到不同的下游任務時是有競爭力的
Related work
Knowledge Distillation 知識蒸餾
參考資料:(58條消息) 深度學習之知識蒸餾(Knowledge Distillation)_AndyJ的學習之旅-CSDN博客_知識蒸餾溫度
簡單來說就是用teacher模型去訓練student模型,通常teacher模型更大而且已經訓練好了,student模型是我們當前需要訓練的模型。在這個過程中,teacher模型是不訓練的。
軟蒸餾 soft distillation
?當teacher模型和student模型拿到相同的圖片時,都進行各自的前向,這時teacher模型就拿到了具有分類信息的feature,在進行softmax之前先除以一個參數?,叫做temperature(蒸餾溫度),然后softmax得到soft labels(區別于one-hot形式的hard-label)。
student模型也是除以同一個?,然后softmax得到一個soft-prediction,我們希望student模型的soft-prediction和teacher模型的soft labels盡量接近,使用KLDivLoss進行兩者之間的差距度量,計算一個對應的損失teacher loss。
在訓練的時候,我們是可以拿的到訓練圖片的真實的ground truth(hard label)的,可以看到上面圖中student模型下面一路,就是預測結果和真是標簽之間計算交叉熵crossentropy。
交叉熵:損失函數|交叉熵損失函數 - 知乎 (zhihu.com)
然后兩路計算的損失:KLDivLoss和CELoss,按照一個加權關系計算得到一個總損失total loss,反向修改參數的時候這個teacher模型是不做訓練的,只依據total loss訓練student模型。
其中表示的是教師網絡的輸出概率,表示學生網絡的輸出概率,表示蒸餾溫度,λ表示Kullback-Leibler散度損失與交叉熵損失之間的權重因子,y表示真實標簽,ψ表示softmax函數
公式很容易可以理解,loss為學生網絡與真實標簽的損失加上學生網絡輸出值與教師網絡輸出值的標簽分布差異。一方面希望學生網絡的輸出值與真實標簽相近,同時還希望其與教師網絡的輸出分布相近,這樣才可以學習到教師網絡對某些錯誤數據與正確數據的相識情況。
?
硬蒸餾 hard diatillation
其中:,一方面使得學生網絡與真實標簽的損失最小,同時也希望與教師網絡得出來的標簽損失最小,這兩個損失各占一半的權重。
對于給定的圖像,與教師相關的硬標簽可能會根據具體的數據增強而改變,而這種選擇比傳統的選擇更好,教師預測與真實labely扮演相同的角色。還要注意,硬標簽也可以通過標簽平滑轉換為軟標簽。
軟蒸餾是限制student和teacher的模型輸出類別分布盡可能接近,而硬蒸餾是限制兩種模型輸出的類別標簽盡可能接近。
KLDivloss
KL散度,又叫相對熵,用于衡量兩個分布(連續分布和離散分布)之間的距離,在knowledge distillation中,兩個分布為teacher模型和student模型的softmax輸出。
當兩個分布很相近時候,對應class的預測值就會很接近,取log之后的差值就會很小,KL散度就很小。當兩個分布完全一致時候,KL散度就等于0。
transformer中加入蒸餾——distillation token
在VIT中時使用class tokens去做分類的,相當于是一個額外的patch,這個patch去學習和別的patch之間的關系,然后連classifier,計算CELoss。在DeiT中為了做蒸餾,又額外加一個distill token,這個distill token也是去學和其他tokens之間的關系,然后連接teacher model計算KLDivLoss,那CELoss和KLDivLoss共同加權組合成一個新的loss取指導student model訓練(知識蒸餾中teacher model不訓練)。
在預測階段,class token和distill token分別產生一個結果,然后將其加權(分別0.5),再加在一起,得到最終的結果做預測。
在patches中加入與class token類似的distillation token,兩者的通過網絡時的計算方式相同,區別在于class token目標是重現ground truth標簽,而distillation token目標是重現教師模型的預測,Distillation token讓模型從教師模型輸出中學習,文章發現:
- 最初class token和distillation token區別很大,余弦相似度為0.06
- 隨著class 和 distillation embedding互相傳播和學習,通過網絡逐漸變得相似,到最后一層,余弦相似度為0.93,相似但不相同
- 當用一個class token替換distillation token時,兩個class token輸出的余弦相似度為0.999,網絡性能與一個class token相近,而加入distillation token的網絡性能明顯提升。這表明distillation token的設定是有效的。
Experiments
DeiT不同參數
定義了與ViT-B參數相同的DeiT-B模型,和更小的DeiT-S、DeiT-Ti模型,區別在于heads數目和embedding dimension不同。超參數如下:
?teacher模型的選擇
實驗發現RegNetY-16GF是效果最好的教師模型,后續實驗默認選擇。
CNN效果更好,這可能是因為transformer可以學到CNN的歸納假設。CNN是有inductive bias的,例如局部感受野,參數共享等,這些設計比較適應于圖像任務,這里將CNN作為teacher,可以通過蒸餾,使得Transformer學習得到CNN的inductive bias,從而提升Transformer對圖像任務的處理能力。
同時還可以發現,學生網絡可以取得超越老師的性能,能夠在準確率和吞吐量權衡方面做的更好。
m小標表示使用了蒸餾策略的網絡模型。↑384表示student在224*224圖像上進行預訓練,然后在384*384圖像上進行fine-tune。
蒸餾策略的選擇
硬蒸餾的性能比軟蒸餾更好。
在pretrain上測試的時候,distillation token和class token兩個一起用性能更佳,這表明兩個token提供了對分類有用的互補信息。只用一個的時候,distillation token性能略好于class token,這可能是因為distillation token里有更多從CNN中學到的歸納假設
與其他模型性能對比
?訓練策略和消融實驗
初始化和超參數?
參數初始化方式:truncated normal distribution(截斷標準分布)
soft蒸餾參數:= 3 , = 0.1
數據增強
總結