Analysis of the Keras example file cnn_seq2seq.py
This example implements a translation task, apparently English to French (I can't read French myself, so I'll take its word for it).
First of all, this code has a bug. I submitted a pull request to fix it:
https://github.com/keras-team/keras/pull/13863/commits/fd44e03a9d17c05aaecc620f8d88ef0fd385254b
but since the repository has not been actively maintained for a long time, it still has not been merged.
Line 68 needs to be changed to:
input_text, target_text, _ = line.split('\t')
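For context, that line sits in the example's data-loading loop. Roughly, it looks like this (a sketch using the example's own variable names; each line of the fra.txt data file is "english<TAB>french<TAB>attribution", which is exactly why the third field has to be thrown away):

with open(data_path, 'r', encoding='utf-8') as f:
    lines = f.read().split('\n')

for line in lines[: min(num_samples, len(lines) - 1)]:
    # the fixed line: discard the attribution field after the second tab
    input_text, target_text, _ = line.split('\t')
    # '\t' marks the start of the target sentence, '\n' marks its end
    target_text = '\t' + target_text + '\n'
    input_texts.append(input_text)
    target_texts.append(target_text)
    for char in input_text:
        input_characters.add(char)
    for char in target_text:
        target_characters.add(char)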
Each character in the training data is then assigned an index. Note that target_token_index contains two extra symbols: the start symbol '\t' and the end symbol '\n':
print(input_token_index)
{' ': 0, '!': 1, '$': 2, '%': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '0': 9, '1': 10, '2': 11, '3': 12, '5': 13, '6': 14, '7': 15, '8': 16, '9': 17, ':': 18, '?': 19, 'A': 20, 'B': 21, 'C': 22, 'D': 23, 'E': 24, 'F': 25, 'G': 26, 'H': 27, 'I': 28, 'J': 29, 'K': 30, 'L': 31, 'M': 32, 'N': 33, 'O': 34, 'P': 35, 'Q': 36, 'R': 37, 'S': 38, 'T': 39, 'U': 40, 'V': 41, 'W': 42, 'Y': 43, 'a': 44, 'b': 45, 'c': 46, 'd': 47, 'e': 48, 'f': 49, 'g': 50, 'h': 51, 'i': 52, 'j': 53, 'k': 54, 'l': 55, 'm': 56, 'n': 57, 'o': 58, 'p': 59, 'q': 60, 'r': 61, 's': 62, 't': 63, 'u': 64, 'v': 65, 'w': 66, 'x': 67, 'y': 68, 'z': 69}
print(target_token_index)
{'\t': 0, '\n': 1, ' ': 2, '!': 3, '$': 4, '%': 5, '&': 6, "'": 7, '(': 8, ')': 9, ',': 10, '-': 11, '.': 12, '0': 13, '1': 14, '2': 15, '3': 16, '5': 17, '8': 18, '9': 19, ':': 20, '?': 21, 'A': 22, 'B': 23, 'C': 24, 'D': 25, 'E': 26, 'F': 27, 'G': 28, 'H': 29, 'I': 30, 'J': 31, 'K': 32, 'L': 33, 'M': 34, 'N': 35, 'O': 36, 'P': 37, 'Q': 38, 'R': 39, 'S': 40, 'T': 41, 'U': 42, 'V': 43, 'Y': 44, 'a': 45, 'b': 46, 'c': 47, 'd': 48, 'e': 49, 'f': 50, 'g': 51, 'h': 52, 'i': 53, 'j': 54, 'k': 55, 'l': 56, 'm': 57, 'n': 58, 'o': 59, 'p': 60, 'q': 61, 'r': 62, 's': 63, 't': 64, 'u': 65, 'v': 66, 'x': 67, 'y': 68, 'z': 69, '\xa0': 70, '?': 71, '?': 72, 'à': 73, '?': 74, 'é': 75, 'ê': 76, 'à': 77, 'a': 78, '?': 79, 'è': 80, 'é': 81, 'ê': 82, '?': 83, '?': 84, '?': 85, '?': 86, 'ù': 87, '?': 88, '?': 89, '\u2009': 90, '’': 91, '\u202f': 92}
對(duì),這個(gè)演示示例中不是對(duì)word進(jìn)行編碼,而是對(duì)字母進(jìn)行編碼,
至于原因,我分析應(yīng)該是這樣的,字母數(shù)量比較少,這個(gè)索引數(shù)也不過只有70個(gè)而已,但如果對(duì)單詞進(jìn)行編碼,那隨隨便便就上千個(gè),維度超大,后面再運(yùn)算的時(shí)候,需要占用極大的內(nèi)存和GPU
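The two index dictionaries printed above are just the sorted character sets enumerated. A rough sketch of that step (names as in the example):

input_characters = sorted(list(input_characters))
target_characters = sorted(list(target_characters))
num_encoder_tokens = len(input_characters)    # 70 here
num_decoder_tokens = len(target_characters)   # 93 here
input_token_index = dict((char, i) for i, char in enumerate(input_characters))
target_token_index = dict((char, i) for i, char in enumerate(target_characters))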
然后對(duì)輸入輸出的句子手動(dòng)進(jìn)行one-hot編碼:
在預(yù)處理中,target_text 的首位補(bǔ)了一個(gè)'\t',代表句子開始了,末尾補(bǔ)了一個(gè)'\n',代表句子結(jié)束了
輸入數(shù)據(jù)的尺寸為:
encoder_input_data.shape (10000, 16, 70)
decoder_input_data.shape (10000, 59, 93)
decoder_target_data.shape (10000, 59, 93)
Both decoder_input_data and decoder_target_data hold the translated sentences; the difference is that decoder_target_data is shifted one step ahead of decoder_input_data: the first position of decoder_input_data is '\t' and the real content only starts at the second position, whereas the first position of decoder_target_data is already the real content.
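Roughly, the example fills these three arrays like this (a sketch; the size variables such as max_encoder_seq_length follow the example's names, and the t - 1 index is what produces the one-step offset just described):

import numpy as np

encoder_input_data = np.zeros(
    (len(input_texts), max_encoder_seq_length, num_encoder_tokens), dtype='float32')
decoder_input_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens), dtype='float32')
decoder_target_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens), dtype='float32')

for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, char in enumerate(input_text):
        encoder_input_data[i, t, input_token_index[char]] = 1.
    for t, char in enumerate(target_text):
        decoder_input_data[i, t, target_token_index[char]] = 1.
        if t > 0:
            # the target is one step ahead of the decoder input
            decoder_target_data[i, t - 1, target_token_index[char]] = 1.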
Why is the translation result also fed in as a model input?
Because during training each output step depends on the previous one. Early in training, if the first character is predicted wrongly, the prediction of the second character would be conditioned on a wrong input, and optimizing the network from a wrong input drives it in the wrong direction. So the correct previous character is supplied instead (the usual teacher-forcing setup), and the network is optimized in the right direction.
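In practice, teacher forcing just means the ground-truth previous characters (decoder_input_data) are passed as the second input during training, while the one-step-ahead decoder_target_data is used as the label. With the model defined in the next section, the training call looks roughly like this (optimizer, batch size and epoch count here are placeholders, not necessarily the example's exact settings):

model.compile(optimizer='adam', loss='categorical_crossentropy')
model.fit([encoder_input_data, decoder_input_data],  # source sentence + shifted target
          decoder_target_data,                       # target shifted one step ahead
          batch_size=64,
          epochs=100,
          validation_split=0.2)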
The network structure:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_2 (InputLayer) (None, None, 93) 0
__________________________________________________________________________________________________
input_1 (InputLayer) (None, None, 70) 0
__________________________________________________________________________________________________
conv1d_4 (Conv1D) (None, None, 256) 71680 input_2[0][0]
__________________________________________________________________________________________________
conv1d_1 (Conv1D) (None, None, 256) 54016 input_1[0][0]
__________________________________________________________________________________________________
conv1d_5 (Conv1D) (None, None, 256) 196864 conv1d_4[0][0]
__________________________________________________________________________________________________
conv1d_2 (Conv1D) (None, None, 256) 196864 conv1d_1[0][0]
__________________________________________________________________________________________________
conv1d_6 (Conv1D) (None, None, 256) 196864 conv1d_5[0][0]
__________________________________________________________________________________________________
conv1d_3 (Conv1D) (None, None, 256) 196864 conv1d_2[0][0]
__________________________________________________________________________________________________
dot_1 (Dot) (None, None, None) 0 conv1d_6[0][0] conv1d_3[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, None, None) 0 dot_1[0][0]
__________________________________________________________________________________________________
dot_2 (Dot) (None, None, 256) 0 activation_1[0][0] conv1d_3[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, None, 512) 0 dot_2[0][0] conv1d_6[0][0]
__________________________________________________________________________________________________
conv1d_7 (Conv1D) (None, None, 64) 98368 concatenate_1[0][0]
__________________________________________________________________________________________________
conv1d_8 (Conv1D) (None, None, 64) 12352 conv1d_7[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, None, 93) 6045 conv1d_8[0][0]
==================================================================================================
Total params: 1,029,917
Trainable params: 1,029,917
Non-trainable params: 0
__________________________________________________________________________________________________
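This summary corresponds to a convolutional encoder and decoder joined by a dot-product attention block. Below is a sketch of how a model matching it can be assembled (filter and Dense sizes are chosen to match the parameter counts above; kernel sizes, dilation rates and causal padding are my reading of the example, so treat them as assumptions):

from keras.layers import Input, Conv1D, Dot, Activation, Concatenate, Dense
from keras.models import Model

latent_dim = 256

# Encoder branch: stacked causal convolutions over the English one-hot sequence
encoder_inputs = Input(shape=(None, num_encoder_tokens))    # input_1: (None, None, 70)
x_enc = Conv1D(latent_dim, 3, activation='relu', padding='causal')(encoder_inputs)
x_enc = Conv1D(latent_dim, 3, activation='relu', padding='causal', dilation_rate=2)(x_enc)
x_enc = Conv1D(latent_dim, 3, activation='relu', padding='causal', dilation_rate=4)(x_enc)

# Decoder branch: the same stack over the shifted French one-hot sequence
decoder_inputs = Input(shape=(None, num_decoder_tokens))    # input_2: (None, None, 93)
x_dec = Conv1D(latent_dim, 3, activation='relu', padding='causal')(decoder_inputs)
x_dec = Conv1D(latent_dim, 3, activation='relu', padding='causal', dilation_rate=2)(x_dec)
x_dec = Conv1D(latent_dim, 3, activation='relu', padding='causal', dilation_rate=4)(x_dec)

# Dot-product attention: every decoder step attends over all encoder steps
attention = Dot(axes=[2, 2])([x_dec, x_enc])      # dot_1: (None, T_dec, T_enc)
attention = Activation('softmax')(attention)      # activation_1
context = Dot(axes=[2, 1])([attention, x_enc])    # dot_2: (None, T_dec, 256)

decoder_combined = Concatenate(axis=-1)([context, x_dec])   # concatenate_1: (..., 512)
out = Conv1D(64, 3, activation='relu', padding='causal')(decoder_combined)
out = Conv1D(64, 3, activation='relu', padding='causal')(out)
decoder_outputs = Dense(num_decoder_tokens, activation='softmax')(out)

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)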
At prediction time, encoder_input_data is just the input sentence, and decoder_input_data starts out as all zeros except that its first position is set to the start symbol '\t'. After the first output character is predicted, it is written into the next position of decoder_input_data, and a for loop then predicts the following position, and so on, up to the expected maximum length.
Because the prediction is an index, it has to be mapped back to a character through a reverse index; when the reverse lookup hits the end symbol '\n', the sentence is finished and the complete prediction has been obtained.
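Putting that together, the decoding loop looks roughly like this (a sketch: reverse_target_char_index is simply the inverse of target_token_index, and input_seq is one row of encoder_input_data kept with its batch dimension, e.g. encoder_input_data[i:i+1]):

import numpy as np

reverse_target_char_index = dict((i, char) for char, i in target_token_index.items())

def decode_sequence(input_seq):
    # All zeros except the start symbol '\t' in the first position
    target_seq = np.zeros((1, max_decoder_seq_length, num_decoder_tokens), dtype='float32')
    target_seq[0, 0, target_token_index['\t']] = 1.

    decoded_sentence = ''
    for t in range(max_decoder_seq_length - 1):
        # Predict the whole sequence but only read position t this round
        output = model.predict([input_seq, target_seq])
        sampled_index = int(np.argmax(output[0, t, :]))
        sampled_char = reverse_target_char_index[sampled_index]
        if sampled_char == '\n':          # end symbol: the sentence is complete
            break
        decoded_sentence += sampled_char
        # Feed the prediction back in as the next decoder input position
        target_seq[0, t + 1, sampled_index] = 1.
    return decoded_sentence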
____________________________________________________
The data preprocessing in lstm_seq2seq.py is identical to the above, so I won't write a separate post for it; its network structure is:
______________________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==============================================================================================================
input_1 (InputLayer) (None, None, 70) 0
______________________________________________________________________________________________________________
input_2 (InputLayer) (None, None, 93) 0
______________________________________________________________________________________________________________
lstm_1 (LSTM) [(None, 256), (None, 256), (None, 256)] 334848 input_1[0][0]
______________________________________________________________________________________________________________
lstm_2 (LSTM) [(None, None, 256), (None, 256), (None, 256)] 358400 input_2[0][0] lstm_1[0][1] lstm_1[0][2]
______________________________________________________________________________________________________________
dense_1 (Dense) (None, None, 93) 23901 lstm_2[0][0]
==============================================================================================================
Total params: 717,149
Trainable params: 717,149
Non-trainable params: 0
______________________________________________________________________________________________________________
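For comparison, the training model behind this summary is the classic LSTM encoder-decoder. A sketch with latent_dim = 256, which matches the parameter counts above (the separate inference models that lstm_seq2seq.py builds on top of these layers are not shown here):

from keras.layers import Input, LSTM, Dense
from keras.models import Model

latent_dim = 256

# Encoder: only the final hidden state and cell state are kept
encoder_inputs = Input(shape=(None, num_encoder_tokens))   # input_1: (None, None, 70)
encoder_outputs, state_h, state_c = LSTM(latent_dim, return_state=True)(encoder_inputs)
encoder_states = [state_h, state_c]

# Decoder: initialised with the encoder states, returns the full sequence
decoder_inputs = Input(shape=(None, num_decoder_tokens))   # input_2: (None, None, 93)
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)
decoder_outputs = Dense(num_decoder_tokens, activation='softmax')(decoder_outputs)

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)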