scheduled sampling_seq2seq
來源:NIPS 2015
本文介紹了decode時采樣的一種新方法,稱為“curriculum learning”(課程學習),對應的采樣方法叫做“scheduled sampling”(計劃采樣)。
-
傳統方法的問題:傳統的神經網絡訓練時和預測時的輸入不一樣。
比如對于上圖這種網絡結構,訓練時上一步的y是真實序列標記,然后做為輸入到下一步,而預測時上一步的y是模型的輸出,然后再做為下一步的輸入。這種訓練和預測時輸入的差異導致了一個問題:當在某一步做出一個錯誤選擇后,后面可能會產生累積錯誤。因為訓練和預測時的前一輸入的選擇不同,導致可能會出現預測時產生的序列在訓練過程中從沒有出現過,導致預測時模型不知道如何選擇。基于此,作者提出盡量讓訓練和預測過程一致的“課程學習” -
“課程學習”和”計劃采樣“
課程學習如上圖所示,訓練時網絡將不再完全采用真實序列標記做為下一步的輸入,而是以一個概率p選擇真實標記,以1-p選擇模型自身的輸出。“計劃采樣”即p的大小在訓練過程中是變化的,就像學習率一樣。作者的思想是:一開始網絡訓練不充分,那么p盡量選大值,即盡量使用真實標記。然后隨著訓練的進行,模型訓練越來越充分,這時p也要減小,即盡量選擇模型自己的輸出。這樣就盡量使模型訓練和預測保持一致。
p隨訓練次數的變化方式有如下選擇:
-
實驗
本文提出的想法在image captioning,Constituency Parsing,speech Recognition等任務上較之前的成果取得了一定improve。
paper地址:http://papers.nips.cc/paper/5956-scheduled-sampling-for-sequence-prediction-with-recurrent-neural-networks.pdf
總結
以上是生活随笔為你收集整理的scheduled sampling_seq2seq的全部內容,希望文章能夠幫你解決所遇到的問題。