知识蒸馏小总结
定義
知識蒸餾是一種模型壓縮方法,是一種基于“教師-學生網絡思想”的訓練方法,由于其簡單,有效,在工業界被廣泛應用。
更簡單的理解:用一個已經訓練好的模型去“教”另一個模型去學習,這兩個模型通常稱為老師-學生模型。
用一個小例子來加深理解:
相關知識
pytorch中的損失函數:
Softmax:將一個數值序列映射到概率空間
# Softmax import torch import torch.nn.functional as F# torch.nn是pytorch中自帶的一個函數庫,里面包含了神經網絡中使用的一些常用函數, # 如具有可學習參數的nn.Conv2d(),nn.Linear()和不具有可學習的參數(如ReLU,pool,DropOut等)(后面這幾個是在nn.functional中) # 在圖片分類問題中,輸入m張圖片,輸出一個m*N的Tensor,其中N是分類類別總數。 # 比如輸入2張圖片,分三類,最后的輸出是一個2*3的Tensor,舉個例子: # torch.randn:用來生成隨機數字的tensor,這些隨機數字滿足標準正態分布(0~1) output = torch.randn(2, 3) print(output) # tensor([[-1.1639, 0.2698, 1.5513], # [-1.0839, 0.3102, -0.8798]]) # 第1,2行分別是第1,2張圖片的結果,假設第123列分別是貓、狗和豬的分類得分。 # 可以看出模型認為第一張為豬,第二張為狗。 然后對每一行使用Softmax,這樣可以得到每張圖片的概率分布。 print(F.softmax(output,dim=1)) # tensor([[0.1167, 0.1955, 0.6878], # [0.8077, 0.0990, 0.0933]])log_Softmax:在Softmax的基礎上進行取對數運算
# log_softmax print(F.log_softmax(output,dim=1)) print(torch.log(F.softmax(output,dim=1))) tensor([[-1.8601, -0.7688, -0.9655],[-0.9205, -1.1949, -1.2075]]) tensor([[-1.8601, -0.7688, -0.9655],[-0.9205, -1.1949, -1.2075]]) # 結果是一致的NLLLoss:對log_softmax和one-hot編碼進行運算
# NLLLoss print(F.nll_loss(torch.tensor([[-1.2, -0.03, -0.5]]), torch.tensor([0])))注:Tensor是張量,所以至少為[[]]!!!
# 通常我們結合 log_softmax 和 nll_loss一起用 output = torch.tensor([[1.2,3,2.6]]) target = torch.tensor([0]) print("output為[[1.2,3,2.6]],若target為第一個,nll_loss為:",F.nll_loss(output,target)) target = torch.tensor([1]) print("output為[[1.2,3,2.6]],若target為第二個,nll_loss為:",F.nll_loss(output,target)) target = torch.tensor([2]) print("output為[[1.2,3,2.6]],若target為第二個,nll_loss為:",F.nll_loss(output,target))輸出結果: output為[[1.2,3,2.6]],若target為第一個,nll_loss為: tensor(-1.2000) output為[[1.2,3,2.6]],若target為第二個,nll_loss為: tensor(-3.) output為[[1.2,3,2.6]],若target為第二個,nll_loss為: tensor(-2.6000)CrossEntropy:衡量兩個概率分布的差別
output = torch.tensor([[1.2,3,2.6]]) log_softmax_output = F.log_softmax(output,dim=1) target = torch.tensor([0]) print(F.nll_loss(log_softmax_output,target))print(F.cross_entropy(output,target)) # 交叉熵自帶softmax輸出結果: tensor(2.4074) tensor(2.4074)圖解KD
圖中貓的圖片的one-hot編碼先輸入到Teacher網絡中進行訓練得到q’,在通過蒸餾得到q’’,最后得到soft targets,然后再把貓的圖片輸入到Student網絡中,得到hard targets并計算損失函數,最后和來自Teacher網絡預測結果的損失函數相加得到最后的損失函數。
知識蒸餾過程
知識蒸餾應用場景
知識蒸餾和遷移學習的基本區別
遷移學習:是從一個領域獲取得模型應用到別的領域的學習
知識蒸餾:是在同一個領域中,從大模型遷移到小模型上的學習
總結
- 上一篇: QQ浏览器调试解决方案
- 下一篇: html中只显示农历的完整代码,很全的显