详解seq2seq
1. 什么是seq2seq
在?然語?處理的很多應?中,輸?和輸出都可以是不定?序列。以機器翻譯為例,輸?可以是?段不定?的英語?本序列,輸出可以是?段不定?的法語?本序列,例如:
英語輸?:“They”、“are”、“watching”、“.”
法語輸出:“Ils”、“regardent”、“.”
當輸?和輸出都是不定?序列時,我們可以使?編碼器—解碼器(encoder-decoder)或者seq2seq模型。序列到序列模型,簡稱seq2seq模型。這兩個模型本質上都?到了兩個循環神經?絡,分別叫做編碼器和解碼器。編碼器?來分析輸?序列,解碼器?來?成輸出序列。兩 個循環神經網絡是共同訓練的。
下圖描述了使?編碼器—解碼器將上述英語句?翻譯成法語句?的?種?法。在訓練數據集中,我們可以在每個句?后附上特殊符號“”(end of sequence)以表?序列的終?。編碼器每個時間步的輸?依次為英語句?中的單詞、標點和特殊符號“”。下圖中使?了編碼器在 最終時間步的隱藏狀態作為輸?句?的表征或編碼信息。解碼器在各個時間步中使?輸?句?的 編碼信息和上個時間步的輸出以及隱藏狀態作為輸?。我們希望解碼器在各個時間步能正確依次 輸出翻譯后的法語單詞、標點和特殊符號“”。需要注意的是,解碼器在最初時間步的輸? ?到了?個表?序列開始的特殊符號“”(beginning of sequence)。
2. 編碼器
編碼器的作?是把?個不定?的輸?序列變換成?個定?的背景變量 c,并在該背景變量中編碼輸?序列信息。常?的編碼器是循環神經?絡。
讓我們考慮批量?小為1的時序數據樣本。假設輸?序列是 x1, . . . , xT,例如 xi 是輸?句?中的第 i 個詞。在時間步 t,循環神經?絡將輸? xt 的特征向量 xt 和上個時間步的隱藏狀態ht?1ht?1變換為當前時間步的隱藏狀態ht。我們可以?函數 f 表達循環神經?絡隱藏層的變換:
ht=f(xt,ht?1)ht=f(xt,ht?1)
接下來,編碼器通過?定義函數 q 將各個時間步的隱藏狀態變換為背景變量:
c=q(h1,…,hT)c=q(h1,…,hT)
例如,當選擇 q(***h*1, . . . , ***h***T ) = ***h***T 時,背景變量是輸?序列最終時間步的隱藏狀態***h***T。
以上描述的編碼器是?個單向的循環神經?絡,每個時間步的隱藏狀態只取決于該時間步及之前的輸??序列。我們也可以使?雙向循環神經?絡構造編碼器。在這種情況下,編碼器每個時間步的隱藏狀態同時取決于該時間步之前和之后的?序列(包括當前時間步的輸?),并編碼了整個序列的信息。
3. 解碼器
剛剛已經介紹,編碼器輸出的背景變量 c 編碼了整個輸?序列 x1, . . . , xT 的信息。給定訓練樣本中的輸出序列 y1, y2, . . . , yT′ ,對每個時間步 t′(符號與輸?序列或編碼器的時間步 t 有區別),解碼器輸出 yt′ 的條件概率將基于之前的輸出序列 y1,…,yt′?1y1,…,yt′?1 和背景變量 c,即:
P(yt′|y1,…,yt′?1,c)P(yt′|y1,…,yt′?1,c)
為此,我們可以使?另?個循環神經?絡作為解碼器。在輸出序列的時間步 t′,解碼器將上?時間步的輸出 yt′?1yt′?1 以及背景變量 c 作為輸?,并將它們與上?時間步的隱藏狀態 st′?1st′?1 變換為當前時間步的隱藏狀態st′。因此,我們可以?函數 g 表達解碼器隱藏層的變換:
st′=g(yt′?1,c,st′?1)st′=g(yt′?1,c,st′?1)
有了解碼器的隱藏狀態后,我們可以使??定義的輸出層和softmax運算來計算P(yt′|y1,…,yt′?1,c)P(yt′|y1,…,yt′?1,c),例如,基于當前時間步的解碼器隱藏狀態 st′、上?時間步的輸出st′?1st′?1以及背景變量 c 來計算當前時間步輸出 yt′ 的概率分布。
4. 訓練模型
根據最?似然估計,我們可以最?化輸出序列基于輸?序列的條件概率:
P(y1,…,yt′?1|x1,…,xT)=T′∏t′=1P(yt′|y1,…,yt′?1,x1,…,xT)P(y1,…,yt′?1|x1,…,xT)=∏t′=1T′P(yt′|y1,…,yt′?1,x1,…,xT)
=T′∏t′=1P(yt′|y1,…,yt′?1,c)=∏t′=1T′P(yt′|y1,…,yt′?1,c)
并得到該輸出序列的損失:
?logP(y1,…,yt′?1|x1,…,xT)=?T′∑t′=1logP(yt′|y1,…,yt′?1,c)?logP(y1,…,yt′?1|x1,…,xT)=?∑t′=1T′logP(yt′|y1,…,yt′?1,c)
在模型訓練中,所有輸出序列損失的均值通常作為需要最小化的損失函數。在上圖所描述的模型預測中,我們需要將解碼器在上?個時間步的輸出作為當前時間步的輸?。與此不同,在訓練中我們也可以將標簽序列(訓練集的真實輸出序列)在上?個時間步的標簽作為解碼器在當前時間步的輸?。這叫作強制教學(teacher forcing)。
5. seq2seq模型預測
以上介紹了如何訓練輸?和輸出均為不定?序列的編碼器—解碼器。本節我們介紹如何使?編碼器—解碼器來預測不定?的序列。
在準備訓練數據集時,我們通常會在樣本的輸?序列和輸出序列后面分別附上?個特殊符號“”表?序列的終?。我們在接下來的討論中也將沿?上?節的全部數學符號。為了便于討論,假設解碼器的輸出是?段?本序列。設輸出?本詞典Y(包含特殊符號“”)的?小為|Y|,輸出序列的最??度為T′。所有可能的輸出序列?共有 O(|y|T′)O(|y|T′) 種。這些輸出序列中所有特殊符號“”后?的?序列將被舍棄。
5.1 貪婪搜索
貪婪搜索(greedy search)。對于輸出序列任?時間步t′,我們從|Y|個詞中搜索出條件概率最?的詞:
yt′=argmaxy∈YP(y|y1,…,yt′?1,c)yt′=argmaxy∈YP(y|y1,…,yt′?1,c)
作為輸出。?旦搜索出“”符號,或者輸出序列?度已經達到了最??度T′,便完成輸出。我們在描述解碼器時提到,基于輸?序列?成輸出序列的條件概率是∏T′t′=1P(yt′|y1,…,yt′?1,c)∏t′=1T′P(yt′|y1,…,yt′?1,c)。我們將該條件概率最?的輸出序列稱為最優輸出序列。而貪婪搜索的主要問題是不能保證得到最優輸出序列。
下?來看?個例?。假設輸出詞典??有“A”“B”“C”和“”這4個詞。下圖中每個時間步
下的4個數字分別代表了該時間步?成“A”“B”“C”和“”這4個詞的條件概率。在每個時間步,貪婪搜索選取條件概率最?的詞。因此,圖10.9中將?成輸出序列“A”“B”“C”“”。該輸出序列的條件概率是0.5 × 0.4 × 0.4 × 0.6 = 0.048。
接下來,觀察下面演?的例?。與上圖中不同,在時間步2中選取了條件概率第??的詞“C”
。由于時間步3所基于的時間步1和2的輸出?序列由上圖中的“A”“B”變為了下圖中的“A”“C”,下圖中時間步3?成各個詞的條件概率發?了變化。我們選取條件概率最?的詞“B”。此時時間步4所基于的前3個時間步的輸出?序列為“A”“C”“B”,與上圖中的“A”“B”“C”不同。因此,下圖中時間步4?成各個詞的條件概率也與上圖中的不同。我們發現,此時的輸出序列“A”“C”“B”“”的條件概率是0.5 × 0.3 × 0.6 × 0.6 = 0.054,?于貪婪搜索得到的輸出序列的條件概率。因此,貪婪搜索得到的輸出序列“A”“B”“C”“”并?最優輸出序列。
5.2 窮舉搜索
如果?標是得到最優輸出序列,我們可以考慮窮舉搜索(exhaustive search):窮舉所有可能的輸出序列,輸出條件概率最?的序列。
雖然窮舉搜索可以得到最優輸出序列,但它的計算開銷 O(|y|T′)O(|y|T′) 很容易過?。例如,當|Y| =
10000且T′ = 10時,我們將評估 1000010=10401000010=1040 個序列:這?乎不可能完成。而貪婪搜索的計
算開銷是 O(|y|T′)O(|y|T′),通常顯著小于窮舉搜索的計算開銷。例如,當|Y| = 10000且T′ = 10時,我
們只需評估 10000?10=10510000?10=105 個序列。
5.3 束搜索
束搜索(beam search)是對貪婪搜索的?個改進算法。它有?個束寬(beam size)超參數。我們將它設為 k。在時間步 1 時,選取當前時間步條件概率最?的 k 個詞,分別組成 k 個候選輸出序列的?詞。在之后的每個時間步,基于上個時間步的 k 個候選輸出序列,從 k |Y| 個可能的輸出序列中選取條件概率最?的 k 個,作為該時間步的候選輸出序列。最終,我們從各個時間步的候選輸出序列中篩選出包含特殊符號“”的序列,并將它們中所有特殊符號“”后?的?序列舍棄,得到最終候選輸出序列的集合。
束寬為2,輸出序列最??度為3。候選輸出序列有A、C、AB、CE、ABD和CED。我們將根據這6個序列得出最終候選輸出序列的集合。在最終候選輸出序列的集合中,我們取以下分數最?的序列作為輸出序列:
1LαlogP(y1,…,yL)=1LαT′∑t′=1logP(yt′|y1,…,yt′?1,c)1LαlogP(y1,…,yL)=1Lα∑t′=1T′logP(yt′|y1,…,yt′?1,c)
其中 L 為最終候選序列?度,α ?般可選為0.75。分?上的 Lα 是為了懲罰較?序列在以上分數中較多的對數相加項。分析可知,束搜索的計算開銷為 O(k|y|T′)O(k|y|T′)。這介于貪婪搜索和窮舉搜索的計算開銷之間。此外,貪婪搜索可看作是束寬為 1 的束搜索。束搜索通過靈活的束寬 k 來權衡計算開銷和搜索質量。
6. Bleu得分
評價機器翻譯結果通常使?BLEU(Bilingual Evaluation Understudy)(雙語評估替補)。對于模型預測序列中任意的?序列,BLEU考察這個?序列是否出現在標簽序列中。
具體來說,設詞數為 n 的?序列的精度為 pn。它是預測序列與標簽序列匹配詞數為 n 的?序列的數量與預測序列中詞數為 n 的?序列的數量之?。舉個例?,假設標簽序列為A、B、C、D、E、F,預測序列為A、B、B、C、D,那么:
P1=預測序列中的1元詞組在標簽序列是否存在的個數預測序列1元詞組的個數之和P1=預測序列中的1元詞組在標簽序列是否存在的個數預測序列1元詞組的個數之和
預測序列一元詞組:A/B/C/D,都在標簽序列里存在,所以P1=4/5,以此類推,p2 = 3/4, p3 = 1/3, p4 = 0。設 lenlabel和lenpredlenlabel和lenpred 分別為標簽序列和預測序列的詞數,那么,BLEU的定義為:
exp(min(0,1?lenlabellenpred))k∏n=1p12nnexp(min(0,1?lenlabellenpred))∏n=1kpn12n
其中 k 是我們希望匹配的?序列的最?詞數。可以看到當預測序列和標簽序列完全?致時,
BLEU為1。
因為匹配較??序列?匹配較短?序列更難,BLEU對匹配較??序列的精度賦予了更?權重。例如,當 pn 固定在0.5時,隨著n的增?,0.512≈0.7,0.514≈0.84,0.518≈0.92,0.5116≈0.960.512≈0.7,0.514≈0.84,0.518≈0.92,0.5116≈0.96。另外,模型預測較短序列往往會得到較?pn 值。因此,上式中連乘項前?的系數是為了懲罰較短的輸出而設的。舉個例?,當k = 2時,假設標簽序列為A、B、C、D、E、F,而預測序列為A、 B。雖然p1 = p2 = 1,但懲罰系數exp(1-6/2) ≈ 0.14,因此BLEU也接近0.14。
總結
- 上一篇: 从此以后谁也别说我不懂LDO了
- 下一篇: 5+App下Mui框架开发仿拼多多App