LSTM神经网络详解
LSTM
長短時(shí)記憶網(wǎng)絡(luò)(Long Short Term Memory Network, LSTM),是一種改進(jìn)之后的循環(huán)神經(jīng)網(wǎng)絡(luò),可以解決RNN無法處理長距離的依賴的問題,目前比較流行。
長短時(shí)記憶網(wǎng)絡(luò)的思路:
原始 RNN 的隱藏層只有一個(gè)狀態(tài),即h,它對于短期的輸入非常敏感。
再增加一個(gè)狀態(tài),即c,讓它來保存長期的狀態(tài),稱為單元狀態(tài)(cell state)。
按照上面時(shí)間維度展開:
在 t 時(shí)刻,LSTM 的輸入有三個(gè):當(dāng)前時(shí)刻網(wǎng)絡(luò)的輸入值 X_t、上一時(shí)刻 LSTM 的輸出值 h_t-1、以及上一時(shí)刻的單元狀態(tài) C_t-1;
LSTM 的輸出有兩個(gè):當(dāng)前時(shí)刻 LSTM 輸出值 h_t、和當(dāng)前時(shí)刻的單元狀態(tài) C_t.
關(guān)鍵問題是:怎樣控制長期狀態(tài) c ?
方法是使用三個(gè)開關(guān)
第一個(gè)開關(guān),負(fù)責(zé)控制繼續(xù)保存長期狀態(tài)c;
第二個(gè)開關(guān),負(fù)責(zé)控制把即時(shí)狀態(tài)輸入到長期狀態(tài)c;
第三個(gè)開關(guān),負(fù)責(zé)控制是否把長期狀態(tài)c作為當(dāng)前的LSTM的輸出。
如何在算法中實(shí)現(xiàn)這三個(gè)開關(guān)?
方法:用 門(gate)
定義:gate 實(shí)際上就是一層全連接層,輸入是一個(gè)向量,輸出是一個(gè) 0到1 之間的實(shí)數(shù)向量。
公式為:
也就是:
gate 如何進(jìn)行控制?
方法:用門的輸出向量按元素乘以我們需要控制的那個(gè)向量
原理:門的輸出是 0到1 之間的實(shí)數(shù)向量,當(dāng)門輸出為 0 時(shí),任何向量與之相乘都會(huì)得到 0 向量,這就相當(dāng)于什么都不能通過;
輸出為 1 時(shí),任何向量與之相乘都不會(huì)有任何改變,這就相當(dāng)于什么都可以通過。
LSTM 的前向計(jì)算
一共有 6 個(gè)公式
遺忘門(forget gate)
它決定了上一時(shí)刻的單元狀態(tài) c_t-1 有多少保留到當(dāng)前時(shí)刻 c_t
輸入門(input gate)
它決定了當(dāng)前時(shí)刻網(wǎng)絡(luò)的輸入 x_t 有多少保存到單元狀態(tài) c_t
輸出門(output gate)
控制單元狀態(tài) c_t 有多少輸出到 LSTM 的當(dāng)前輸出值 h_t
遺忘門的計(jì)算為:
遺忘門的計(jì)算公式中:
W_f 是遺忘門的權(quán)重矩陣,[h_t-1, x_t] 表示把兩個(gè)向量連接成一個(gè)更長的向量,b_f是遺忘門的偏置項(xiàng),σ 是 sigmoid 函數(shù)。
輸入門的計(jì)算公式:
根據(jù)上一次的輸出和本次輸入來計(jì)算當(dāng)前輸入的單元狀態(tài):
當(dāng)前時(shí)刻的單元狀態(tài) c_t 的計(jì)算:由上一次的單元狀態(tài) c_t-1 按元素乘以遺忘門 f_t,再用當(dāng)前輸入的單元狀態(tài) c_t 按元素乘以輸入門 i_t,再將兩個(gè)積加和:這樣,就可以把當(dāng)前的記憶 c_t 和長期的記憶 c_t-1 組合在一起,形成了新的單元狀態(tài) c_t。由于遺忘門的控制,它可以保存很久很久之前的信息,由于輸入門的控制,它又可以避免當(dāng)前無關(guān)緊要的內(nèi)容進(jìn)入記憶。
輸出門的計(jì)算公式:
LSTM 的反向傳播訓(xùn)練算法
主要有三步:
1. 前向計(jì)算每個(gè)神經(jīng)元的輸出值,一共有 5 個(gè)變量,計(jì)算方法就是前一部分:
2. 反向計(jì)算每個(gè)神經(jīng)元的誤差項(xiàng)值。與 RNN 一樣,LSTM 誤差項(xiàng)的反向傳播也是包括兩個(gè)方向:
一個(gè)是沿時(shí)間的反向傳播,即從當(dāng)前 t 時(shí)刻開始,計(jì)算每個(gè)時(shí)刻的誤差項(xiàng);
一個(gè)是將誤差項(xiàng)向上一層傳播。
3. 根據(jù)相應(yīng)的誤差項(xiàng),計(jì)算每個(gè)權(quán)重的梯度。
gate 的激活函數(shù)定義為 sigmoid 函數(shù),輸出的激活函數(shù)為 tanh 函數(shù),導(dǎo)數(shù)分別為:
具體的推導(dǎo)公式為:
具體的推導(dǎo)公式為:
目標(biāo)是要學(xué)習(xí) 8 組參數(shù),如下圖所示:
又權(quán)重矩陣 W 都是由兩個(gè)矩陣拼接而成,這兩部分在反向傳播中使用不同的公式,因此在后續(xù)的推導(dǎo)中,權(quán)重矩陣也要被寫為分開的兩個(gè)矩陣。
接著就來求兩個(gè)方向的誤差,和一個(gè)梯度計(jì)算。
1.誤差項(xiàng)沿時(shí)間的反向傳遞:
定義 t 時(shí)刻的誤差項(xiàng):
目的是要計(jì)算出 t-1 時(shí)刻的誤差項(xiàng):
利用 h_t c_t 的定義,和全導(dǎo)數(shù)公式,可以得到 將誤差項(xiàng)向前傳遞到任意k時(shí)刻的公式:
2. 將誤差項(xiàng)傳遞到上一層的公式:
3. 權(quán)重梯度的計(jì)算:
以上就是 LSTM 的訓(xùn)練算法的全部公式。
總結(jié)
以上是生活随笔為你收集整理的LSTM神经网络详解的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【Solr原理】Leader Shard
- 下一篇: html5语异性元素,异性的5句性暗示