【论文学习】ICLR2021,鲁棒早期学习法:抑制记忆噪声标签ROBUST EARLY-LEARNING: HINDERING THE MEMORIZATION OF NOISY LABELS
??論文來自ICLR2021,作者是悉尼大學(xué)的Xiaobo Xia博士。論文基于早停和彩票假說,提出了一種處理標(biāo)簽噪聲問題的新方法。我就論文要點(diǎn)學(xué)習(xí)整理,目前還沒有找到開源代碼,我實(shí)現(xiàn)了一份在本文中給出。我對論文中部分試驗(yàn)復(fù)現(xiàn),并補(bǔ)充進(jìn)行一些新試驗(yàn)。
??論文鏈接
文章目錄
- 一、理論要點(diǎn)
- 二、公式推導(dǎo)
- 三、效果對比
- 四、我的代碼及部分試驗(yàn)復(fù)現(xiàn)
- 1,核心代碼
- 2,我的試驗(yàn)
- 2.1,不同噪聲率下觀察“早停”的作用
- 2.2,不同τ\tauτ參數(shù)下觀察“彩票假說”現(xiàn)象
- 2.3,不同噪聲率和不同τ\tauτ參數(shù)下觀察本文算法去噪效果
- 2.4,算法局部修改試驗(yàn)
- 2.4.1 (1?τ1-\tau1?τ)
- 2.4.2 L1正則
- 2.4.3 gig_{i}gi?
- 五、讀后感
一、理論要點(diǎn)
這篇文章基于兩點(diǎn)主要理論:一是深度網(wǎng)絡(luò)會先記憶標(biāo)簽清晰的訓(xùn)練數(shù)據(jù),然后記憶標(biāo)簽有噪聲的訓(xùn)練數(shù)據(jù)。因此,用早停法學(xué)習(xí)可抑制噪聲標(biāo)簽。二是彩票假說指出深度網(wǎng)絡(luò)中只有部分參數(shù)對模型起作用,本文因此認(rèn)為只有部分參數(shù)對擬合干凈標(biāo)簽有用,稱之為關(guān)鍵參數(shù),而其他參數(shù)則傾向于擬合噪聲標(biāo)簽,稱之為非關(guān)鍵參數(shù)。在每次迭代中,對不同的參數(shù)執(zhí)行不同的更新規(guī)則以逐漸使非關(guān)鍵參數(shù)歸零,以此抑制噪聲標(biāo)簽發(fā)揮作用。二、公式推導(dǎo)
文中總共有以下6個(gè)公式:
min L(W;S)L(\mathcal{W};S)L(W;S) = min1n∑i=1nL(W;(xi,yi))+λ∥W∥1\frac{1}{n}\sum \limits_{i=1} ^{n}L(\mathcal{W};(x_{i},y_{i})) + \lambda\begin{Vmatrix}\mathcal{W}\end{Vmatrix}_{1}n1?i=1∑n?L(W;(xi?,yi?))+λ∥∥?W?∥∥?1? ???????(1)
W(k+1)←W(k)?η(?L(W(k);S?)?W(k)+λsgn(W(k)))\mathcal{W}(k+1)\leftarrow\mathcal{W}(k) - \eta(\frac{\partial L(\mathcal{W}(k);S^{*})}{\partial\mathcal{W}(k)}+\lambda sgn(\mathcal{W}(k)))W(k+1)←W(k)?η(?W(k)?L(W(k);S?)?+λsgn(W(k)))??????(2)
gi=∣?L(Wi;S)×Wi∣,i∈[m]g_{i}=|\nabla L(\tiny W_{i}\normalsize ;S) \times \tiny W_{i}\normalsize |, i\in[m]gi?=∣?L(Wi?;S)×Wi?∣,i∈[m]?????????????????(3)
mc=(1?τ)mm_{c}=(1-\tau)mmc?=(1?τ)m ????????????????????????(4)
Wc(k+1)←Wc(k)?η((1?τ)?L(Wc(k);S?~)?Wc(k)+λsgn(Wc(k)))\mathcal{W}_{c}(k+1)\leftarrow\mathcal{W}_{c}(k) - \eta((1-\tau)\frac{\partial L(\mathcal{W}_{c}(k);\tilde{S^{*}})}{\partial\mathcal{W}_{c}(k)}+\lambda sgn(\mathcal{W}_{c}(k)))Wc?(k+1)←Wc?(k)?η((1?τ)?Wc?(k)?L(Wc?(k);S?~)?+λsgn(Wc?(k))) (5)
Wn(k+1)←Wn(k)?ηλsgn(Wn(k))\mathcal{W}_{n}(k+1)\leftarrow\mathcal{W}_{n}(k) - \eta \lambda sgn(\mathcal{W}_{n}(k))Wn?(k+1)←Wn?(k)?ηλsgn(Wn?(k)) ??????????? (6)
考慮給損失函數(shù)加入一個(gè)l1正則項(xiàng),如式(1);
根據(jù)式(1)的損失函數(shù),使用SGD方式更新權(quán)重,如式(2);
對于任一個(gè)參數(shù)Wi∈Wm\tiny W_{i}\normalsize \in {\mathcal{W}^{m}}Wi?∈Wm,根據(jù)式(3)計(jì)算一個(gè)參考量gig_{i}gi?,根據(jù)gig_{i}gi?對W\mathcal{W}W排序。根據(jù)式(4)計(jì)算得到關(guān)鍵參數(shù)的個(gè)數(shù)為mcm_{c}mc?個(gè),然后W\mathcal{W}W排序考前的mcm_{c}mc?個(gè)參數(shù)就是關(guān)鍵參數(shù)Wc\mathcal{W}_{c}Wc?,其余參數(shù)為非關(guān)鍵參數(shù)Wn\mathcal{W}_{n}Wn?;
對于關(guān)鍵參數(shù)按照(5)式更新,注意梯度乘上了一個(gè)衰減系數(shù)(1?τ1-\tau1?τ),作者說這是為了防止訓(xùn)練過程中過度自信下降。(對此不是很理解)
對于非關(guān)鍵參數(shù)按照(6)式更新,此時(shí)把梯度置零,只保留了正則化項(xiàng),這會導(dǎo)致這些非關(guān)鍵參數(shù)逐漸縮小直到接近于0而失去作用。
其中公式(3)比較難理解,為什么用這個(gè)指標(biāo)來判斷哪些是關(guān)鍵參數(shù)呢?原文的解釋如下:
構(gòu)造一個(gè)函數(shù)G(t)=L(tW;S)G(t)=L(\mathcal{tW};S)G(t)=L(tW;S),則
G′(t)=?L(tW;S)TWG'(t)=\nabla L(\mathcal{tW};S)^{T}\mathcal{W}G′(t)=?L(tW;S)TW,
令t=1t=1t=1,有:
G′(1)=?L(W;S)TW=<?L(W;S),W>G'(1)=\nabla L(\mathcal{W};S)^{T}\mathcal{W}=<\nabla L(\mathcal{W};S),\mathcal{W}>G′(1)=?L(W;S)TW=<?L(W;S),W>(<>表示內(nèi)積)
滿足最優(yōu)化條件時(shí),?L(W;S)=0\nabla L(\mathcal{W};S)=0?L(W;S)=0,因此G′(1)=0G'(1)=0G′(1)=0,
由G′(1)=0G'(1)=0G′(1)=0可得到(3)式
說實(shí)話,這個(gè)部分我沒有看懂,有理解的小伙伴可以講一講。
三、效果對比
??作者指出由于本文的主要目的是提出一個(gè)新的概念,并且本文沒有使用多種綜合措施,所以效果趕不上該領(lǐng)域在2020年的兩個(gè)SOTA方法:DivideMix和SELF,除了這兩個(gè)之外,本文方法比其他模型的效果都好。作者進(jìn)行了大量對比試驗(yàn),其中在MNIST、F-MNIST、CIFAR-10、CIFAR-100這四個(gè)數(shù)據(jù)集上的試驗(yàn)如表1。
??作者隨后又在Food-101和WebVision這兩個(gè)數(shù)據(jù)集上進(jìn)行了試驗(yàn),結(jié)論類似。
??作者又進(jìn)行了消融試驗(yàn),試驗(yàn)發(fā)現(xiàn)模型效果對參數(shù)τ\tauτ不敏感。
四、我的代碼及部分試驗(yàn)復(fù)現(xiàn)
1,核心代碼
??由于沒有開源,我按照自己理解進(jìn)行代碼實(shí)現(xiàn)。根據(jù)文中公式,該算法只涉及到參數(shù)更新過程,因此只需要在pytorch中重寫SGD即可實(shí)現(xiàn)本算法中說的關(guān)鍵/非關(guān)鍵參數(shù)分別更新;然后在訓(xùn)練的時(shí)候加入早停即可。
??重寫的newSGD代碼如下,主要是增加了tau和decay1兩個(gè)參數(shù)。tau就是文中τ\tauτ噪聲率,注意式(6)和式(5)的區(qū)別,對于非關(guān)鍵參數(shù),就是把梯度項(xiàng)置零,只有正則化項(xiàng)了,所以代碼可以非常簡潔的寫出來。在SGD中,weight_decay就是正則化項(xiàng),但是torch1.6給出的SGD用的是l2正則,而論文中給出的公式用的是l1正則,所以我又新加了一個(gè)weight_decay1用來實(shí)現(xiàn)l1正則。
然后在訓(xùn)練時(shí)把原來的SGD替換即可
from newSGD import newSGD optimizer = newSGD(net.parameters(), lr=0.01,momentum=0.9, tau=0.2, weight_decay1=1e-3)2,我的試驗(yàn)
??為了加快速度,試驗(yàn)主要在MNIST數(shù)據(jù)集和LeNet上進(jìn)行,個(gè)別補(bǔ)充進(jìn)行了CIFAR10上的ResNet18試驗(yàn)。試驗(yàn)參數(shù)配置:epoch = 100, BatchSize = 128, lr=0.01 ,momentum = 0.9, weight_decay = 0.001。由于L1正則不便于觀察規(guī)律(原因見2.4.2節(jié)介紹),下面試驗(yàn)使用L2正則。噪聲數(shù)據(jù)只使用同步噪聲標(biāo)簽,即每個(gè)類別按照噪聲率抽取樣本隨機(jī)變換為任意其他類別的標(biāo)簽。注意噪聲只存在于訓(xùn)練集,測試集不含噪聲,是干凈的。
2.1,不同噪聲率下觀察“早停”的作用
??神經(jīng)網(wǎng)絡(luò)在訓(xùn)練早期只學(xué)習(xí)干凈標(biāo)簽,在訓(xùn)練的后期才逐漸學(xué)習(xí)噪聲標(biāo)簽,因此可以用早停法抑制噪聲標(biāo)簽。我們先觀察這個(gè)現(xiàn)象,試驗(yàn)中不使用本文提到的新算法,只使用LeNet和交叉熵?fù)p失:
??從圖中可以看出幾個(gè)特點(diǎn):
(1)隨著噪聲率的增加,訓(xùn)練集訓(xùn)練精度明顯降低,但測試集仍能達(dá)到較高的精度,例如即使噪聲含量80%時(shí),此時(shí)訓(xùn)練集精度不足35%,但測試集精度最高仍可達(dá)到85%以上。這說明神經(jīng)網(wǎng)絡(luò)本身就對噪聲有一定的魯棒性。
(2)含噪聲時(shí),網(wǎng)絡(luò)早期先學(xué)習(xí)干凈數(shù)據(jù),所以測試集仍可以達(dá)到很高精度,但后期開始記憶噪聲數(shù)據(jù),導(dǎo)致測試集精度下降。所以早停肯定可以起到抑制噪聲標(biāo)簽的作用。
(3)對比噪聲含量80%和90%的訓(xùn)練精度曲線(圖中淺藍(lán)和深藍(lán)虛線),我們發(fā)現(xiàn)一個(gè)有意思的地方,90%噪聲的訓(xùn)練精度后期比80%的還高。我的解釋是:由于數(shù)據(jù)集就10個(gè)類別,90%噪聲時(shí)幾乎等于完全隨機(jī),網(wǎng)絡(luò)從一開始就意識到這沒有任何規(guī)律可以找,干脆就快速發(fā)展記憶數(shù)據(jù)能力了。這很有意思,值得繼續(xù)思考。
2.2,不同τ\tauτ參數(shù)下觀察“彩票假說”現(xiàn)象
??彩票假說指出神經(jīng)網(wǎng)絡(luò)只有少部分參數(shù)真正發(fā)揮作用。上面newSGD算法中給出的τ\tauτ會使得網(wǎng)絡(luò)中每個(gè)參數(shù)張量中都有占比例為τ\tauτ的參數(shù)在經(jīng)過充分訓(xùn)練后趨于0,因此使用這個(gè)代碼就可以觀察到彩票假說現(xiàn)象。我們使用不含噪聲的數(shù)據(jù)來觀察這個(gè)現(xiàn)象:
從圖中可以看出,神經(jīng)網(wǎng)絡(luò)具有驚人的參數(shù)壓縮潛力,τ=0.995\tau=0.995τ=0.995時(shí),相當(dāng)于只有0.5%的參數(shù)起作用,測試精度仍可達(dá)到95%以上。τ=0.999\tau=0.999τ=0.999時(shí),訓(xùn)練結(jié)束后,我們把其中conv2層的權(quán)重絕對值reshape到25×96以及fc1層的權(quán)重絕對值進(jìn)行可視化,畫出來如下圖。可見其中確實(shí)只有極少的參數(shù)存在了,但即使這么稀疏的參數(shù),仍然可以達(dá)到70%以上的精度。τ=0.9999\tau=0.9999τ=0.9999時(shí),網(wǎng)絡(luò)的效果才有明顯的下降,但仍有接近40%的精度。
2.3,不同噪聲率和不同τ\tauτ參數(shù)下觀察本文算法去噪效果
又在CIFAR10上用ResNet18做了部分試驗(yàn),效果和上圖類似:
??從圖中可以看出:
??τ=0\tau=0τ=0就是論文Table1中的CE,使用本算法之后,τ\tauτ較大時(shí)起到的作用只是隨著訓(xùn)練的繼續(xù),測試精度下降變少,但考慮到早停時(shí),最佳精度發(fā)生在初期,使用本方法后和CE并無明顯優(yōu)勢。這可能是MNIST數(shù)據(jù)集過于簡單,加的噪聲模式也比較簡單,所以看不出論文算法的優(yōu)勢。這個(gè)和論文中的Table1也是一致的。
2.4,算法局部修改試驗(yàn)
??對算法中的衰減系數(shù)(1?τ1-\tau1?τ),l1正則,劃分關(guān)鍵參數(shù)的判據(jù)gig_{i}gi?等的作用和必要性仍不太理解,因此我們從試驗(yàn)對比中觀察它們的效果。
2.4.1 (1?τ1-\tau1?τ)
對于式(5)中的(1?τ1-\tau1?τ)項(xiàng),在原本的SGD公式中是沒有的,作者說這里增加此項(xiàng)能夠抑制過度自信下降的作用,下圖以20%噪聲率為例,對比了使用(1?τ1-\tau1?τ)和不使用(1?τ1-\tau1?τ)的效果。
從圖中可以看出,當(dāng)τ\tauτ=0.8或0.9時(shí),(1?τ1-\tau1?τ)項(xiàng)能夠起到一定的正則效果,會避免訓(xùn)練的后期記憶噪聲數(shù)據(jù),但效果并不明顯。
2.4.2 L1正則
下圖給出L1正則和L2正則在20%噪聲率時(shí)的測試集精度曲線,可以看出L1正則的正則化效果更重,即使τ\tauτ較小時(shí)也可以防止模型后期記憶噪聲數(shù)據(jù)。但是L1正則在模型初期的精度表現(xiàn)不如L2正則,也就是說如果使用早停的話其效果不如L2。由于L1正則過強(qiáng)的正則化效果,不便于觀察2.1,2.2節(jié)中的現(xiàn)象,所以前序試驗(yàn)都使用L2正則進(jìn)行。
2.4.3 gig_{i}gi?
??gig_{i}gi?是劃分關(guān)鍵和非關(guān)鍵參數(shù)的依據(jù),作者在公式(3)中給出的計(jì)算方法是參數(shù)的梯度和參數(shù)的點(diǎn)積的絕對值。作者的推導(dǎo)過程我沒有看懂(數(shù)學(xué)太菜了!),但我可以用試驗(yàn)檢驗(yàn)以下這個(gè)表達(dá)式的充分必要性,也就是
- 使用式(3)能否把參數(shù)壓縮到少量關(guān)鍵參數(shù);
- 使用式(3)確定的關(guān)鍵參數(shù)是否真的關(guān)鍵,即是否能以少量關(guān)鍵參數(shù)仍達(dá)到和全量參數(shù)接近的精度;
??文中公式(3)我在代碼中寫成 g = (d_p * p).abs(),我又嘗試了其他幾種劃分關(guān)鍵和非關(guān)鍵參數(shù)的方法,
??方法B:g = d_p.abs() + p.abs()
??方法C:提前隨機(jī)選定每個(gè)參數(shù)張量中占比τ\tauτ的位置制成mask,然后每輪參數(shù)更新時(shí),這些位置對應(yīng)的參數(shù)的梯度置0。
??我們定義絕對值大于0.001的參數(shù)為有效參數(shù),上圖的第一行三個(gè)圖表示的是隨著訓(xùn)練輪數(shù),網(wǎng)絡(luò)中的總有效參數(shù)量的變化情況,第二行三個(gè)圖表示隨著訓(xùn)練輪數(shù),測試集精度的變化。
??從上面圖中對比我們可以看出,對于本文方法(最左圖),在不同的τ\tauτ下都能使有效參數(shù)量逐漸收縮到占比總參數(shù)量約為τ\tauτ的位置處,并且精度仍能夠有著不錯(cuò)的保持。而對于另外兩種方法,它們不能夠保持有效參數(shù)不再壓縮,而是會出現(xiàn)參數(shù)量不斷的下降,精度也掉的一塌糊涂,說明這兩種方法不能有效區(qū)分關(guān)鍵參數(shù)和非關(guān)鍵參數(shù),也就不能夠在訓(xùn)練后期把關(guān)鍵參數(shù)穩(wěn)定住。實(shí)際上我還嘗試了很多其他的參數(shù)劃分方法,都沒有文中方法有效。
??所以說文中式(3)給出的關(guān)鍵參數(shù)劃分判據(jù)是非常有效的,對公式的推導(dǎo)過程后續(xù)再慢慢吃透。
??(補(bǔ)充說明,第一行圖中可以明顯觀察到有效參數(shù)量每次都是在75epoch和95epoch處有明顯轉(zhuǎn)折,這個(gè)原因是網(wǎng)絡(luò)使用的默認(rèn)的標(biāo)準(zhǔn)參數(shù)初始化方式,參數(shù)的分布概率是固定的,而同樣的weight_decay下參數(shù)的收縮速率也是固定的,所以會有同批的參數(shù)被同時(shí)收縮到0.001以下。)
五、讀后感
??本文提出的方法實(shí)際上主要是從彩票假說和神經(jīng)網(wǎng)絡(luò)早期學(xué)習(xí)干凈標(biāo)簽這兩點(diǎn)出發(fā),本文方法的噪聲標(biāo)簽抑制能力實(shí)際上達(dá)不到SOTA。但彩票假說中只是指出了神經(jīng)網(wǎng)絡(luò)中真正關(guān)鍵的參數(shù)很少,卻也沒有指出有效的提取關(guān)鍵參數(shù)的方法,而本文提出的劃分關(guān)鍵參數(shù)的方法非常有意思,有可能提供一種新的模型壓縮的思路。這篇論文的寫作也非常好,值得學(xué)習(xí)。
<補(bǔ)充 2021-02-09>更具tau修正梯度的核心部分代碼修改如下,能夠進(jìn)一步提高精度,加快運(yùn)算速度。 m = p.numel()if tau != 0 and m>1000:g = (d_p * p).abs()if m>10000:gf = g.flatten()[:10000]mn = int(10000*(1-100/math.sqrt(m)*(1-tau)))if mn > 9990:mn = 9990kth,_ = gf.kthvalue(mn)else:mn = int(p.numel()*tau)kth,_ = g.flatten().kthvalue(mn)d_p = torch.where(g < kth, torch.zeros_like(d_p), d_p)
總結(jié)
以上是生活随笔為你收集整理的【论文学习】ICLR2021,鲁棒早期学习法:抑制记忆噪声标签ROBUST EARLY-LEARNING: HINDERING THE MEMORIZATION OF NOISY LABELS的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Pytorch CookBook
- 下一篇: LeNet试验(五)观察“彩票假说”现象