keras 的 example 文件 addition_rnn.py 解析
生活随笔
收集整理的這篇文章主要介紹了
keras 的 example 文件 addition_rnn.py 解析
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
該代碼實現了通過神經網絡來計算兩個三位數的相加
先生成一堆訓練數據,打印一下
print(questions[:10])
print(expected[:10])
結果為:
[' 31+991', ' 46+154', ' 0+2', ' 9+9', ' 1+7', ' 827+2', ' 97+09', ' 0+8', ' 5+3', ' 5+239']
['212 ', '515 ', '2 ', '18 ', '8 ', '730 ', '169 ', '8 ', '8 ', '937 ']
編碼的時候,questions是前面加空格,后面是真實的計算字符串,也就是右對齊
expected是后面加空格,也就是說expected字符串是左對齊
然后進行編碼,參考下面的questions編碼方式
31+991
[[ True False False False False False False False False False False False][False False False False False True False False False False False False][False False False True False False False False False False False False][False True False False False False False False False False False False][False False False False False False False False False False False True][False False False False False False False False False False False True][False False False True False False False False False False False False]]46+154
[[ True False False False False False False False False False False False][False False False False False False True False False False False False][False False False False False False False False True False False False][False True False False False False False False False False False False][False False False True False False False False False False False False][False False False False False False False True False False False False][False False False False False False True False False False False False]]0+2
[[ True False False False False False False False False False False False][ True False False False False False False False False False False False][ True False False False False False False False False False False False][ True False False False False False False False False False False False][False False True False False False False False False False False False][False True False False False False False False False False False False][False False False False True False False False False False False False]]
上面的一行,分別對應[空格, +, 0,1,2,3,4,5,6,7,8,9],所以字符串進行了類似的one-hot編碼
expected也是一樣:
212
[[False False False False True False False False False False False False][False False False True False False False False False False False False][False False False False True False False False False False False False][ True False False False False False False False False False False False]]
515
[[False False False False False False False True False False False False][False False False True False False False False False False False False][False False False False False False False True False False False False][ True False False False False False False False False False False False]]
2
[[False False False False True False False False False False False False][ True False False False False False False False False False False False][ True False False False False False False False False False False False][ True False False False False False False False False False False False]]
因為expected中沒有加號,所以第二列永遠為False
?
x_train.shape和y_train.shape分別為(45000, 7, 12) (45000, 4, 12)
神經網絡模型為:
__________________________________________________________________________________________
Layer (type) Output Shape Param #
==========================================================================================
lstm_1 (LSTM) (None, 128) 72192
__________________________________________________________________________________________
repeat_vector_1 (RepeatVector) (None, 4, 128) 0
__________________________________________________________________________________________
lstm_2 (LSTM) (None, 4, 128) 131584
__________________________________________________________________________________________
time_distributed_1 (TimeDistributed) (None, 4, 12) 1548
==========================================================================================
Total params: 205,324
Trainable params: 205,324
Non-trainable params: 0
__________________________________________________________________________________________
上面可以看到,兩個LSTM的輸出shape不一樣,一個是(None, 128),另一個是(None, 4, 128),這是因為第一個RNN的return_sequences為False,而第一個RNN的return_sequences為True
代碼解釋參考官方教程:
https://keras.io/zh/examples/addition_rnn/
?
——————————————————————
總目錄
keras的example文件解析
總結
以上是生活随笔為你收集整理的keras 的 example 文件 addition_rnn.py 解析的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 人工智能入门:keras的example
- 下一篇: keras 的 example 文件 a