keras 的 example 文件 mnist_acgan.py 解析
這是一個(gè)gan網(wǎng)絡(luò),大致分為兩個(gè)神經(jīng)網(wǎng)絡(luò),一個(gè)是生成網(wǎng)絡(luò),另一個(gè)是判別網(wǎng)絡(luò)
判別網(wǎng)絡(luò)的結(jié)構(gòu)大致如下:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 28, 28, 1) 0
__________________________________________________________________________________________________
sequential_1 (Sequential) (None, 12544) 387840 input_1[0][0]
__________________________________________________________________________________________________
generation (Dense) (None, 1) 12545 sequential_1[1][0]
__________________________________________________________________________________________________
auxiliary (Dense) (None, 10) 125450 sequential_1[1][0]
==================================================================================================
Total params: 525,835
Trainable params: 525,835
Non-trainable params: 0
__________________________________________________________________________________________________
其中 Sequential1 的網(wǎng)絡(luò)結(jié)構(gòu)為:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_1 (Conv2D) (None, 14, 14, 32) 320
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 14, 14, 32) 0
_________________________________________________________________
dropout_1 (Dropout) (None, 14, 14, 32) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 14, 14, 64) 18496
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 14, 14, 64) 0
_________________________________________________________________
dropout_2 (Dropout) (None, 14, 14, 64) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 7, 7, 128) 73856
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 7, 7, 128) 0
_________________________________________________________________
dropout_3 (Dropout) (None, 7, 7, 128) 0
_________________________________________________________________
conv2d_4 (Conv2D) (None, 7, 7, 256) 295168
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU) (None, 7, 7, 256) 0
_________________________________________________________________
dropout_4 (Dropout) (None, 7, 7, 256) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 12544) 0
=================================================================
Total params: 387,840
Trainable params: 387,840
Non-trainable params: 0
_________________________________________________________________
就是跟定一張圖片,通過一堆卷積、激活、dropout之后,最后拉伸生成一個(gè)12544維度的一個(gè)向量,然后跟兩個(gè)Dense,一個(gè)是判斷是否為真圖片(generation?),另一個(gè)是判斷是哪個(gè)數(shù)字(auxiliary)
生成網(wǎng)絡(luò)的結(jié)構(gòu)大致如下:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_3 (InputLayer) (None, 1) 0
__________________________________________________________________________________________________
input_2 (InputLayer) (None, 100) 0
__________________________________________________________________________________________________
embedding_1 (Embedding) (None, 1, 100) 1000 input_3[0][0]
__________________________________________________________________________________________________
multiply_1 (Multiply) (None, 1, 100) 0 input_2[0][0]embedding_1[0][0]
__________________________________________________________________________________________________
sequential_2 (Sequential) (None, 28, 28, 1) 2656897 multiply_1[0][0]
==================================================================================================
Total params: 2,657,897
Trainable params: 2,657,321
Non-trainable params: 576
__________________________________________________________________________________________________
其中Sequential1的網(wǎng)絡(luò)結(jié)構(gòu)為:
____________________________________________________________________________________________________
Layer (type) Output Shape Param #
====================================================================================================
dense_1 (Dense) (None, 3456) 349056
____________________________________________________________________________________________________
reshape_1 (Reshape) (None, 3, 3, 384) 0
____________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTranspose) (None, 7, 7, 192) 1843392
____________________________________________________________________________________________________
batch_normalization_1 (BatchNormalization) (None, 7, 7, 192) 768
____________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTranspose) (None, 14, 14, 96) 460896
____________________________________________________________________________________________________
batch_normalization_2 (BatchNormalization) (None, 14, 14, 96) 384
____________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTranspose) (None, 28, 28, 1) 2401
====================================================================================================
Total params: 2,656,897
Trainable params: 2,656,321
Non-trainable params: 576
____________________________________________________________________________________________________
也就是有兩個(gè)輸入,一個(gè)是隨機(jī)數(shù)(input_2),另一個(gè)是類別(input_3),就是數(shù)字幾
其中輸入 input_3 經(jīng)過一個(gè)Embedding 之后和 和 input_2 相乘,這里是一個(gè)點(diǎn)乘,也叫內(nèi)積,相乘之后shape不變,生成一個(gè)100維的向量,再經(jīng)過Dense、Reshape 和 Conv2DTranspose 之后,生成一張28*28的黑白圖片
?
上面生成網(wǎng)絡(luò)和判別網(wǎng)絡(luò)合并起來,大致結(jié)構(gòu)為:
________________________________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
========================================================================================================================
input_4 (InputLayer) (None, 100) 0
________________________________________________________________________________________________________________________
input_5 (InputLayer) (None, 1) 0
________________________________________________________________________________________________________________________
model_2 (Model) (None, 28, 28, 1) 2657897 input_4[0][0]input_5[0][0]
________________________________________________________________________________________________________________________
model_1 (Model) [(None, 1), (None, 10)] 525835 model_2[1][0]
========================================================================================================================
Total params: 3,183,732
Trainable params: 2,657,321
Non-trainable params: 526,411
________________________________________________________________________________________________________________________
?
這里有一個(gè) train_on_batch 加上參數(shù) sample_weight ,這個(gè)sample_weight 是對應(yīng)?[y, aux_y] ,
print(len(disc_sample_weight))print(len(disc_sample_weight[0]))print(len(disc_sample_weight[1]))tmp = [y, aux_y]print(len(tmp))print(len(tmp[0]))print(len(tmp[1]))
大致就是這么個(gè)意思,y,也就是是否為真實(shí),這個(gè)計(jì)算損失的結(jié)果就正常計(jì)算,稍微有一點(diǎn)就是真實(shí)圖片的y的 label 值為 0.95
aux_y的損失,由于對于新生成的圖片,計(jì)算其分類沒有啥意義,所以最初是把它的損失結(jié)果直接乘以0,而對于mnist庫中的圖片,把分類的損失乘以2,彌補(bǔ)一下
這種情況下,我們訓(xùn)練判別網(wǎng)絡(luò)?discriminator 一次
然后我們再生成一堆圖片,然后把是否為真圖片的標(biāo)簽,全部設(shè)置為0.95,然后訓(xùn)練一次?combined 網(wǎng)絡(luò),該網(wǎng)絡(luò)中?discriminator.trainable = False,所以這里僅訓(xùn)練了生成網(wǎng)絡(luò)
訓(xùn)練過程基本就是這些,其他代碼就是計(jì)算測試的損失和保存生成圖片
如下圖,效果不錯(cuò)
?
——————————————————————
總目錄
keras的example文件解析
總結(jié)
以上是生活随笔為你收集整理的keras 的 example 文件 mnist_acgan.py 解析的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 天空之城(君をのせて)主题曲
- 下一篇: keras 的 example 文件 m