keras 的 example 文件 lstm_stateful.py 解析
該程序要通過一個LSTM來實現擬合窗口平均數的功能
先看輸入輸出數據,
print(x_train[:10])
[[[-0.08453234]][[ 0.02169589]][[ 0.07949955]][[ 0.00898136]][[ 0.0405444 ]][[-0.0227726 ]][[ 0.03033169]][[ 0.03801032]][[ 0.04372695]][[ 0.03803725]]]
print(y_train[:10])
[[-0.03537864][-0.03141822][ 0.05059772][ 0.04424045][ 0.02476288][ 0.0088859 ][ 0.00377955][ 0.03417101][ 0.04086864][ 0.0408821 ]]
?y_train就是 x_train 兩兩數的平均值,不過 x_train 的最初的第一個數舍去了,看起來 y_train 的第一個數沒什么道理似的,這個不必關心
x_train.shape: ?(800, 1, 1)
y_train.shape: ?(800, 1)
x_test.shape: ?(200, 1, 1)
y_test.shape: ?(200, 1)
然后是神經網絡結構,無論stateful是否為True,結構都是一樣的:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_1 (LSTM) (1, 20) 1760
_________________________________________________________________
dense_1 (Dense) (1, 1) 21
=================================================================
Total params: 1,781
Trainable params: 1,781
Non-trainable params: 0
_________________________________________________________________
注意:在stateful = True 時,我們要在fit中手動使得shuffle = False
預測結果:
?
從訓練時的打印來看:
stateful=True 時,
Epoch 10 / 10
Train on 800 samples, validate on 200 samples
Epoch 1/1
800/800 [==============================] - 2s 2ms/step - loss: 4.9922e-06 - val_loss: 2.9957e-06
而?stateful=False 時,
Epoch 10/10
800/800 [==============================] - 2s 2ms/step - loss: 8.5024e-04 - val_loss: 9.6397e-04
原理部分介紹可以參考
https://blog.csdn.net/qq_27586341/article/details/88239404
——————————————————————
總目錄
keras的example文件解析
總結
以上是生活随笔為你收集整理的keras 的 example 文件 lstm_stateful.py 解析的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: keras 的 example 文件 i
- 下一篇: keras 的 example 文件 l