貌离神合的RNN与ODE:花式RNN简介
作者丨蘇劍林
單位丨廣州火焰信息科技有限公司
研究方向丨NLP,神經(jīng)網(wǎng)絡(luò)
個(gè)人主頁丨kexue.fm
本來筆者已經(jīng)決心不玩 RNN 了,但是在上個(gè)星期思考時(shí)忽然意識(shí)到 RNN 實(shí)際上對應(yīng)了 ODE(常微分方程)的數(shù)值解法,這為我一直以來想做的事情——用深度學(xué)習(xí)來解決一些純數(shù)學(xué)問題——提供了思路。事實(shí)上這是一個(gè)頗為有趣和有用的結(jié)果,遂介紹一翻。順便地,本文也涉及到了自己動(dòng)手編寫 RNN 的內(nèi)容,所以本文也可以作為編寫自定義的 RNN 層的一個(gè)簡單教程。
注:本文并非前段時(shí)間的熱點(diǎn)“神經(jīng) ODE [1]”的介紹(但有一定的聯(lián)系)。
RNN基本
什么是RNN??
眾所周知,RNN 是“循環(huán)神經(jīng)網(wǎng)絡(luò)(Recurrent Neural Network)”,跟 CNN 不同,RNN 可以說是一類模型的總稱,而并非單個(gè)模型。簡單來講,只要是輸入向量序列 (x1,x2,…,xT),輸出另外一個(gè)向量序列 (y1,y2,…,yT),并且滿足如下遞歸關(guān)系的模型,都可以稱為 RNN。
也正因?yàn)槿绱?#xff0c;原始的樸素 RNN,還有改進(jìn)的如 GRU、LSTM、SRU 等模型,我們都稱為 RNN,因?yàn)樗鼈兌伎梢宰鳛樯鲜降囊粋€(gè)特例。還有一些看上去與 RNN 沒關(guān)的內(nèi)容,比如前不久介紹的 CRF 的分母的計(jì)算,實(shí)際上也是一個(gè)簡單的 RNN。
說白了,RNN 其實(shí)就是遞歸計(jì)算。
自己編寫RNN
這里我們先介紹如何用 Keras 簡單快捷地編寫自定義的 RNN。?
事實(shí)上,不管在 Keras 還是純 tensorflow 中,要自定義自己的 RNN 都不算復(fù)雜。在 Keras 中,只要寫出每一步的遞歸函數(shù);而在 tensorflow 中,則稍微復(fù)雜一點(diǎn),需要將每一步的遞歸函數(shù)封裝為一個(gè) RNNCell 類。
下面介紹用 Keras 實(shí)現(xiàn)最基本的一個(gè) RNN:
代碼非常簡單:
from?keras.layers?import?Layer
import?keras.backend?as?K
class?My_RNN(Layer):
????def?__init__(self,?output_dim,?**kwargs):
????????self.output_dim?=?output_dim?#?輸出維度
????????super(My_RNN,?self).__init__(**kwargs)
????def?build(self,?input_shape):?#?定義可訓(xùn)練參數(shù)
????????self.kernel1?=?self.add_weight(name='kernel1',
??????????????????????????????????????shape=(self.output_dim,?self.output_dim),
??????????????????????????????????????initializer='glorot_normal',
??????????????????????????????????????trainable=True)
????????self.kernel2?=?self.add_weight(name='kernel2',
??????????????????????????????????????shape=(input_shape[-1],?self.output_dim),
??????????????????????????????????????initializer='glorot_normal',
??????????????????????????????????????trainable=True)
????????self.bias?=?self.add_weight(name='kernel',
??????????????????????????????????????shape=(self.output_dim,),
??????????????????????????????????????initializer='glorot_normal',
??????????????????????????????????????trainable=True)
????def?step_do(self,?step_in,?states):?#?定義每一步的迭代
????????step_out?=?K.tanh(K.dot(states[0],?self.kernel1)?+
??????????????????????????K.dot(step_in,?self.kernel2)?+
??????????????????????????self.bias)
????????return?step_out,?[step_out]
????def?call(self,?inputs):?#?定義正式執(zhí)行的函數(shù)
????????init_states?=?[K.zeros((K.shape(inputs)[0],
????????????????????????????????self.output_dim)
??????????????????????????????)]?#?定義初始態(tài)(全零)
????????outputs?=?K.rnn(self.step_do,?inputs,?init_states)?#?循環(huán)執(zhí)行step_do函數(shù)
????????return?outputs[0]?#?outputs是一個(gè)tuple,outputs[0]為最后時(shí)刻的輸出,
??????????????????????????#?outputs[1]為整個(gè)輸出的時(shí)間序列,output[2]是一個(gè)list,
??????????????????????????#?是中間的隱藏狀態(tài)。
????def?compute_output_shape(self,?input_shape):
????????return?(input_shape[0],?self.output_dim)
可以看到,雖然代碼行數(shù)不少,但大部分都只是固定格式的語句,真正定義 RNN 的,是 step_do 這個(gè)函數(shù),這個(gè)函數(shù)接受兩個(gè)輸入:step_in 和 states。其中 step_in 是一個(gè) (batch_size, input_dim) 的張量,代表當(dāng)前時(shí)刻的樣本 xt,而 states 是一個(gè) list,代表 yt?1 及一些中間變量。
特別要注意的是,states 是一個(gè)張量的 list,而不是單個(gè)張量,這是因?yàn)樵谶f歸過程中可能要同時(shí)傳遞多個(gè)中間變量,而不僅僅是 yt?1 一個(gè),比如 LSTM 就需要有兩個(gè)態(tài)張量。最后 step_do 要返回 yt 和新的 states,這是 step_do 這步的函數(shù)的編寫規(guī)范。?
而 K.rnn 這個(gè)函數(shù),接受三個(gè)基本參數(shù)(還有其他參數(shù),請自行看官方文檔),其中第一個(gè)參數(shù)就是剛才寫好的 step_do 函數(shù),第二個(gè)參數(shù)則是輸入的時(shí)間序列,第三個(gè)是初始態(tài),跟前面說的 states 一致,所以很自然 init_states 也是一個(gè)張量的 list,默認(rèn)情況下我們會(huì)選擇全零初始化。
ODE基本
什么是ODE?
ODE 就是“常微分方程(Ordinary Differential Equation)”,這里指的是一般的常微分方程組:
研究 ODE 的領(lǐng)域通常也直接稱為“動(dòng)力學(xué)”、“動(dòng)力系統(tǒng)”,這是因?yàn)榕nD力學(xué)通常也就只是一組 ODE 而已。
ODE可以產(chǎn)生非常豐富的函數(shù)。比如 e^t 其實(shí)就是 x˙=x 的解,sint 和 cost 都是 x¨+x=0 的解(初始條件不同)。事實(shí)上,我記得確實(shí)有一些教程是直接通過微分方程 x˙=x 來定義 e^t 函數(shù)的。除了這些初等函數(shù),很多我們能叫得上名字但不知道是什么鬼的特殊函數(shù),都是通過 ODE 導(dǎo)出來的,比如超幾何函數(shù)、勒讓德函數(shù)、貝塞爾函數(shù)...
總之,ODE 能產(chǎn)生并且已經(jīng)產(chǎn)生了各種各樣千奇百怪的函數(shù)~
數(shù)值解ODE?
能精確求出解析解的 ODE 其實(shí)是非常少的,所以很多時(shí)候我們都需要數(shù)值解法。?
ODE 的數(shù)值解已經(jīng)是一門非常成熟的學(xué)科了,這里我們也不多做介紹,僅引入最基本的由數(shù)學(xué)家歐拉提出來的迭代公式:
這里的 h 是步長。歐拉的解法來源很簡單,就是用:
來近似導(dǎo)數(shù)項(xiàng) x˙(t)。只要給定初始條件 x(0),我們就可以根據(jù) (4) 一步步迭代算出每個(gè)時(shí)間點(diǎn)的結(jié)果。
ODE與RNN?
ODE也是RNN
大家仔細(xì)對比 (4) 和 (1),發(fā)現(xiàn)有什么聯(lián)系了嗎?
在 (1) 中,t 是一個(gè)整數(shù)變量,在 (4) 中,t 是一個(gè)浮點(diǎn)變量,除此之外,(4) 跟 (1) 貌似就沒有什么明顯的區(qū)別了。事實(shí)上,在 (4) 中我們可以以 h 為時(shí)間單位,記 t=nh,那么 (4) 變成了:
可以看到現(xiàn)在 (6) 中的時(shí)間變量 n 也是整數(shù)了。這樣一來,我們就知道:ODE 的歐拉解法 (4) 實(shí)際上就是 RNN 的一個(gè)特例罷了。這里我們也許可以間接明白為什么 RNN 的擬合能力如此之強(qiáng)了(尤其是對于時(shí)間序列數(shù)據(jù)),我們看到 ODE 可以產(chǎn)生很多復(fù)雜的函數(shù),而 ODE 只不過是 RNN 的一個(gè)特例罷了,所以 RNN 也就可以產(chǎn)生更為復(fù)雜的函數(shù)了。?
用RNN解ODE?
于是,我們就可以寫一個(gè) RNN 來解 ODE 了,比如《兩生物種群競爭模型》[2] 中的例子:
我們可以寫出:
from?keras.layers?import?Layer
import?keras.backend?as?K
class?ODE_RNN(Layer):
????def?__init__(self,?steps,?h,?**kwargs):
????????self.steps?=?steps
????????self.h?=?h
????????super(ODE_RNN,?self).__init__(**kwargs)
????def?step_do(self,?step_in,?states):?#?定義每一步的迭代
????????x?=?states[0]
????????r1,r2,a1,a2,iN1,iN2?=?0.1,0.3,0.0001,0.0002,0.002,0.003
????????_1?=?r1?*?x[:,0]?*?(1?-?iN1?*?x[:,0])?-?a1?*?x[:,0]?*?x[:,1]
????????_2?=?r2?*?x[:,1]?*?(1?-?iN2?*?x[:,1])?-?a2?*?x[:,0]?*?x[:,1]
????????_1?=?K.expand_dims(_1,?1)
????????_2?=?K.expand_dims(_2,?1)
????????_?=?K.concatenate([_1,?_2],?1)
????????step_out?=?x?+?self.h?*?_
????????return?step_out,?[step_out]
????def?call(self,?inputs):?#?這里的inputs就是初始條件
????????init_states?=?[inputs]
????????zeros?=?K.zeros((K.shape(inputs)[0],
?????????????????????????self.steps,
?????????????????????????K.shape(inputs)[1]))?#?迭代過程用不著外部輸入,所以
??????????????????????????????????????????????#?指定一個(gè)全零輸入,只為形式上的傳入
????????outputs?=?K.rnn(self.step_do,?zeros,?init_states)?#?循環(huán)執(zhí)行step_do函數(shù)
????????return?outputs[1]?#?這次我們輸出整個(gè)結(jié)果序列
????def?compute_output_shape(self,?input_shape):
????????return?(input_shape[0],?self.steps,?input_shape[1])
from?keras.models?import?Sequential
import?numpy?as?np
import?matplotlib.pyplot?as?plt
steps,h?=?1000,0.1
M?=?Sequential()
M.add(ODE_RNN(steps,?h,?input_shape=(2,)))
M.summary()
#?直接前向傳播就輸出解了
result?=?M.predict(np.array([[100,?150]]))[0]?#?以[100,?150]為初始條件進(jìn)行演算
times?=?np.arange(1,?steps+1)?*?h
#?繪圖
plt.plot(times,?result[:,0])
plt.plot(times,?result[:,1])
plt.savefig('test.png')
整個(gè)過程很容易理解,只不過有兩點(diǎn)需要指出一下。首先,由于方程組 (7) 只有兩維,而且不容易寫成矩陣運(yùn)算,因此我在 step_do 函數(shù)中是直接逐位操作的(代碼中的 x[:,0],x[:,1]),如果方程本身維度較高,而且能寫成矩陣運(yùn)算,那么直接利用矩陣運(yùn)算寫會(huì)更加高效;然后,我們可以看到,寫完整個(gè)模型之后,直接 predict 就輸出結(jié)果了,不需要“訓(xùn)練”。
▲?RNN解兩物種的競爭模型
反推ODE參數(shù)
前一節(jié)的介紹也就是說,RNN 的前向傳播跟 ODE 的歐拉解法是對應(yīng)的,那么反向傳播又對應(yīng)什么呢?
在實(shí)際問題中,有一類問題稱為“模型推斷”,它是在已知實(shí)驗(yàn)數(shù)據(jù)的基礎(chǔ)上,猜測這批數(shù)據(jù)符合的模型(機(jī)理推斷)。這類問題的做法大概分兩步,第一步是猜測模型的形式,第二步是確定模型的參數(shù)。假定這批數(shù)據(jù)可以由一個(gè) ODE 描述,并且這個(gè) ODE 的形式已經(jīng)知道了,那么就需要估計(jì)里邊的參數(shù)。
如果能夠用公式完全解出這個(gè) ODE,那么這就只是一個(gè)非常簡單的回歸問題罷了。但前面已經(jīng)說過,多數(shù) ODE 都沒有公式解,所以數(shù)值方法就必須了。這其實(shí)就是 ODE 對應(yīng)的 RNN 的反向傳播所要做的事情:前向傳播就是解 ODE(RNN 的預(yù)測過程),反向傳播自然就是推斷 ODE 的參數(shù)了(RNN 的訓(xùn)練過程)。這是一個(gè)非常有趣的事實(shí):ODE 的參數(shù)推斷是一個(gè)被研究得很充分的課題,然而在深度學(xué)習(xí)這里,只是 RNN 的一個(gè)最基本的應(yīng)用罷了。
我們把剛才的例子的微分方程的解數(shù)據(jù)保存下來,然后只取幾個(gè)點(diǎn),看看能不能反推原來的微分方程出來,解數(shù)據(jù)為:
假設(shè)就已知這有限的點(diǎn)數(shù)據(jù),然后假定方程 (7) 的形式,求方程的各個(gè)參數(shù)。我們修改一下前面的代碼:
from?keras.layers?import?Layer
import?keras.backend?as?K
def?my_init(shape,?dtype=None):?#?需要定義好初始化,這相當(dāng)于需要實(shí)驗(yàn)估計(jì)參數(shù)的量級
????return?K.variable([0.1,?0.1,?0.001,?0.001,?0.001,?0.001])
class?ODE_RNN(Layer):
????def?__init__(self,?steps,?h,?**kwargs):
????????self.steps?=?steps
????????self.h?=?h
????????super(ODE_RNN,?self).__init__(**kwargs)
????def?build(self,?input_shape):?#?將原來的參數(shù)設(shè)為可訓(xùn)練的參數(shù)
????????self.kernel?=?self.add_weight(name='kernel',?
??????????????????????????????????????shape=(6,),
??????????????????????????????????????initializer=my_init,
??????????????????????????????????????trainable=True)
????def?step_do(self,?step_in,?states):?#?定義每一步的迭代
????????x?=?states[0]
????????r1,r2,a1,a2,iN1,iN2?=?(self.kernel[0],?self.kernel[1],
???????????????????????????????self.kernel[2],?self.kernel[3],
???????????????????????????????self.kernel[4],?self.kernel[5])
????????_1?=?r1?*?x[:,0]?*?(1?-?iN1?*?x[:,0])?-?a1?*?x[:,0]?*?x[:,1]
????????_2?=?r2?*?x[:,1]?*?(1?-?iN2?*?x[:,1])?-?a2?*?x[:,0]?*?x[:,1]
????????_1?=?K.expand_dims(_1,?1)
????????_2?=?K.expand_dims(_2,?1)
????????_?=?K.concatenate([_1,?_2],?1)
????????step_out?=?x?+?self.h?*?K.clip(_,?-1e5,?1e5)?#?防止梯度爆炸
????????return?step_out,?[step_out]
????def?call(self,?inputs):?#?這里的inputs就是初始條件
????????init_states?=?[inputs]
????????zeros?=?K.zeros((K.shape(inputs)[0],
?????????????????????????self.steps,
?????????????????????????K.shape(inputs)[1]))?#?迭代過程用不著外部輸入,所以
??????????????????????????????????????????????#?指定一個(gè)全零輸入,只為形式上的傳入
????????outputs?=?K.rnn(self.step_do,?zeros,?init_states)?#?循環(huán)執(zhí)行step_do函數(shù)
????????return?outputs[1]?#?這次我們輸出整個(gè)結(jié)果序列
????def?compute_output_shape(self,?input_shape):
????????return?(input_shape[0],?self.steps,?input_shape[1])
from?keras.models?import?Sequential
from?keras.optimizers?import?Adam
import?numpy?as?np
import?matplotlib.pyplot?as?plt
steps,h?=?50,?1?#?用大步長,減少步數(shù),削弱長時(shí)依賴,也加快推斷速度
series?=?{0:?[100,?150],
??????????10:?[165,?283],
??????????15:?[197,?290],
??????????30:?[280,?276],
??????????36:?[305,?269],
??????????40:?[318,?266],
??????????42:?[324,?264]}
M?=?Sequential()
M.add(ODE_RNN(steps,?h,?input_shape=(2,)))
M.summary()
#?構(gòu)建訓(xùn)練樣本
#?其實(shí)就只有一個(gè)樣本序列,X為初始條件,Y為后續(xù)時(shí)間序列
X?=?np.array([series[0]])
Y?=?np.zeros((1,?steps,?2))
for?i,j?in?series.items():
????if?i?!=?0:
????????Y[0,?int(i/h)-1]?+=?series[i]
#?自定義loss
#?在訓(xùn)練的時(shí)候,只考慮有數(shù)據(jù)的幾個(gè)時(shí)刻,沒有數(shù)據(jù)的時(shí)刻被忽略
def?ode_loss(y_true,?y_pred):
????T?=?K.sum(K.abs(y_true),?2,?keepdims=True)
????T?=?K.cast(K.greater(T,?1e-3),?'float32')
????return?K.sum(T?*?K.square(y_true?-?y_pred),?[1,?2])
M.compile(loss=ode_loss,
??????????optimizer=Adam(1e-4))
M.fit(X,?Y,?epochs=10000)?#?用低學(xué)習(xí)率訓(xùn)練足夠多輪
#?用訓(xùn)練出來的模型重新預(yù)測,繪圖,比較結(jié)果
result?=?M.predict(np.array([[100,?150]]))[0]
times?=?np.arange(1,?steps+1)?*?h
plt.clf()
plt.plot(times,?result[:,0],?color='blue')
plt.plot(times,?result[:,1],?color='green')
plt.plot(series.keys(),?[i[0]?for?i?in?series.values()],?'o',?color='blue')
plt.plot(series.keys(),?[i[1]?for?i?in?series.values()],?'o',?color='green')
plt.savefig('test.png')
結(jié)果可以用一張圖來看:
▲?RNN做ODE的參數(shù)估計(jì)效果
(散點(diǎn):有限的實(shí)驗(yàn)數(shù)據(jù),曲線:估計(jì)出來的模型)
顯然結(jié)果是讓人滿意的。
又到總結(jié)
本文在一個(gè)一般的框架下介紹了 RNN 模型及其在 Keras 下的自定義寫法,然后揭示了 ODE 與 RNN 的聯(lián)系。在此基礎(chǔ)上,介紹了用 RNN 直接求解 ODE 以及用 RNN 反推 ODE 參數(shù)的基本思路。
需要提醒讀者的是,在 RNN 模型的反向傳播中,要謹(jǐn)慎地做好初始化和截?cái)嗵幚硖幚?#xff0c;并且選擇好學(xué)習(xí)率等,以防止梯度爆炸的出現(xiàn)(梯度消失只是優(yōu)化得不夠好,梯度爆炸則是直接崩潰了,解決梯度爆炸問題尤為重要)。
總之,梯度消失和梯度爆炸在 RNN 中是一個(gè)很經(jīng)典的困難,事實(shí)上,LSTM、GRU 等模型的引入,根本原因就是為了解決 RNN 的梯度消失問題,而梯度爆炸則是通過使用 tanh 或 sigmoid 激活函數(shù)來解決的。
但是如果用 RNN 解決 ODE 的話,我們就沒有選擇激活函數(shù)的權(quán)利了(激活函數(shù)就是 ODE 的一部分),所以只能謹(jǐn)慎地做好初始化及其他處理。據(jù)說,只要謹(jǐn)慎做好初始化,普通 RNN 中用 relu 作為激活函數(shù)都是無妨的。
相關(guān)鏈接
[1].?Tian Qi C, Yulia R, Jesse B, David D. Neural Ordinary Differential Equations. arXiv preprint arXiv:1806.07366, 2018.
[2]. 兩生物種群競爭模型
https://kexue.fm/archives/3120
點(diǎn)擊以下標(biāo)題查看作者其他文章:?
從無監(jiān)督構(gòu)建詞庫看「最小熵原理」
基于CNN的閱讀理解式問答模型:DGCNN
再談最小熵原理:飛象過河之句模版和語言結(jié)構(gòu)
再談變分自編碼器VAE:從貝葉斯觀點(diǎn)出發(fā)
變分自編碼器VAE:這樣做為什么能成?
簡明條件隨機(jī)場CRF介紹 | 附帶純Keras實(shí)現(xiàn)
▲?戳我查看招募詳情
#作 者 招 募#
讓你的文字被很多很多人看到,喜歡我們不如加入我們
關(guān)于PaperWeekly
PaperWeekly 是一個(gè)推薦、解讀、討論、報(bào)道人工智能前沿論文成果的學(xué)術(shù)平臺(tái)。如果你研究或從事 AI 領(lǐng)域,歡迎在公眾號后臺(tái)點(diǎn)擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
▽ 點(diǎn)擊 |?閱讀原文?| 進(jìn)入作者博客
總結(jié)
以上是生活随笔為你收集整理的貌离神合的RNN与ODE:花式RNN简介的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 使用PaddleFluid和Tensor
- 下一篇: COLING 2018 最佳论文解读:序