
Analysis of the Keras example file mnist_acgan.py


This is a GAN, which consists of two neural networks: a generator and a discriminator.

The discriminator's structure is roughly as follows:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, 28, 28, 1)    0
__________________________________________________________________________________________________
sequential_1 (Sequential)       (None, 12544)        387840      input_1[0][0]
__________________________________________________________________________________________________
generation (Dense)              (None, 1)            12545       sequential_1[1][0]
__________________________________________________________________________________________________
auxiliary (Dense)               (None, 10)           125450      sequential_1[1][0]
==================================================================================================
Total params: 525,835
Trainable params: 525,835
Non-trainable params: 0
__________________________________________________________________________________________________

The structure of sequential_1 inside it is:

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 14, 14, 32)        320
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 14, 14, 32)        0
_________________________________________________________________
dropout_1 (Dropout)          (None, 14, 14, 32)        0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 14, 14, 64)        18496
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 14, 14, 64)        0
_________________________________________________________________
dropout_2 (Dropout)          (None, 14, 14, 64)        0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 7, 7, 128)         73856
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 7, 7, 128)         0
_________________________________________________________________
dropout_3 (Dropout)          (None, 7, 7, 128)         0
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 7, 7, 256)         295168
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 7, 7, 256)         0
_________________________________________________________________
dropout_4 (Dropout)          (None, 7, 7, 256)         0
_________________________________________________________________
flatten_1 (Flatten)          (None, 12544)             0
=================================================================
Total params: 387,840
Trainable params: 387,840
Non-trainable params: 0
_________________________________________________________________

In other words, given an image, a stack of convolution, activation and dropout layers flattens it into a 12544-dimensional vector, which then feeds two Dense heads: one judges whether the image is real (generation), the other predicts which digit it is (auxiliary).
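To make this concrete, here is a minimal sketch of how such a two-headed discriminator can be built. It is not copied from the script: the exact kernel sizes, strides and dropout rate are assumptions chosen so that the output shapes and parameter counts match the summary above.

from keras.layers import Input, Conv2D, LeakyReLU, Dropout, Flatten, Dense
from keras.models import Sequential, Model

# convolutional feature extractor (sequential_1 in the summary above)
cnn = Sequential([
    Conv2D(32, 3, strides=2, padding='same', input_shape=(28, 28, 1)),
    LeakyReLU(0.2),
    Dropout(0.3),
    Conv2D(64, 3, strides=1, padding='same'),
    LeakyReLU(0.2),
    Dropout(0.3),
    Conv2D(128, 3, strides=2, padding='same'),
    LeakyReLU(0.2),
    Dropout(0.3),
    Conv2D(256, 3, strides=1, padding='same'),
    LeakyReLU(0.2),
    Dropout(0.3),
    Flatten(),                                        # -> (None, 12544)
])

image = Input(shape=(28, 28, 1))
features = cnn(image)
fake = Dense(1, activation='sigmoid', name='generation')(features)   # real vs. fake
aux = Dense(10, activation='softmax', name='auxiliary')(features)    # which digit 0-9
discriminator = Model(image, [fake, aux])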

The generator's structure is roughly as follows:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_3 (InputLayer)            (None, 1)            0
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, 100)          0
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 1, 100)       1000        input_3[0][0]
__________________________________________________________________________________________________
multiply_1 (Multiply)           (None, 1, 100)       0           input_2[0][0]
                                                                 embedding_1[0][0]
__________________________________________________________________________________________________
sequential_2 (Sequential)       (None, 28, 28, 1)    2656897     multiply_1[0][0]
==================================================================================================
Total params: 2,657,897
Trainable params: 2,657,321
Non-trainable params: 576
__________________________________________________________________________________________________

The structure of sequential_2 inside it is:

____________________________________________________________________________________________________
Layer (type)                                 Output Shape                            Param #
====================================================================================================
dense_1 (Dense)                              (None, 3456)                            349056
____________________________________________________________________________________________________
reshape_1 (Reshape)                          (None, 3, 3, 384)                       0
____________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTranspose)         (None, 7, 7, 192)                       1843392
____________________________________________________________________________________________________
batch_normalization_1 (BatchNormalization)   (None, 7, 7, 192)                       768
____________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTranspose)         (None, 14, 14, 96)                      460896
____________________________________________________________________________________________________
batch_normalization_2 (BatchNormalization)   (None, 14, 14, 96)                      384
____________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTranspose)         (None, 28, 28, 1)                       2401
====================================================================================================
Total params: 2,656,897
Trainable params: 2,656,321
Non-trainable params: 576
____________________________________________________________________________________________________

In other words, there are two inputs: one is random noise (input_2), and the other is the class label (input_3), i.e. which digit to generate.

input_3 goes through an Embedding and is then multiplied with input_2. This is an element-wise multiplication (not an inner product, which would collapse the vectors to a scalar), so the shape is unchanged and the result is still a 100-dimensional vector. That vector then passes through the Dense, Reshape and Conv2DTranspose layers to produce a 28*28 grayscale image, as sketched below.
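A minimal sketch of this two-input generator head and its deconvolution stack follows; the kernel sizes, strides and activations are assumptions picked to reproduce the shapes in the summary. Note that the printed summary multiplies against the un-flattened (None, 1, 100) embedding, while the sketch flattens it first; the element-wise result is the same.

from keras.layers import (Input, Embedding, Flatten, multiply, Dense, Reshape,
                          Conv2DTranspose, BatchNormalization)
from keras.models import Sequential, Model

latent_size = 100

# deconvolution stack (sequential_2 in the summary above)
cnn = Sequential([
    Dense(3 * 3 * 384, activation='relu', input_dim=latent_size),
    Reshape((3, 3, 384)),
    Conv2DTranspose(192, 5, strides=1, padding='valid', activation='relu'),  # -> 7x7x192
    BatchNormalization(),
    Conv2DTranspose(96, 5, strides=2, padding='same', activation='relu'),    # -> 14x14x96
    BatchNormalization(),
    Conv2DTranspose(1, 5, strides=2, padding='same', activation='tanh'),     # -> 28x28x1
])

latent = Input(shape=(latent_size,))             # input_2: 100-dim random noise
image_class = Input(shape=(1,), dtype='int32')   # input_3: digit label 0-9
cls = Flatten()(Embedding(10, latent_size)(image_class))
h = multiply([latent, cls])                      # element-wise product, still 100-dim
fake_image = cnn(h)                              # (None, 28, 28, 1)
generator = Model([latent, image_class], fake_image)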


Combining the generator and discriminator above into one model, the structure is roughly:

________________________________________________________________________________________________________________________
Layer (type)                           Output Shape               Param #       Connected to
========================================================================================================================
input_4 (InputLayer)                   (None, 100)                0
________________________________________________________________________________________________________________________
input_5 (InputLayer)                   (None, 1)                  0
________________________________________________________________________________________________________________________
model_2 (Model)                        (None, 28, 28, 1)          2657897       input_4[0][0]
                                                                                input_5[0][0]
________________________________________________________________________________________________________________________
model_1 (Model)                        [(None, 1), (None, 10)]    525835        model_2[1][0]
========================================================================================================================
Total params: 3,183,732
Trainable params: 2,657,321
Non-trainable params: 526,411
________________________________________________________________________________________________________________________
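This combined model can be wired up roughly like this (a sketch that assumes the generator and discriminator built above; the optimizer settings are assumptions). The key point is that the discriminator is frozen before the combined model is compiled:

from keras.layers import Input
from keras.models import Model
from keras.optimizers import Adam

latent = Input(shape=(latent_size,))
image_class = Input(shape=(1,), dtype='int32')
fake_image = generator([latent, image_class])

discriminator.trainable = False                  # freeze it inside the combined model
fake, aux = discriminator(fake_image)
combined = Model([latent, image_class], [fake, aux])
combined.compile(
    optimizer=Adam(lr=0.0002, beta_1=0.5),
    loss=['binary_crossentropy', 'sparse_categorical_crossentropy'])

Because the discriminator was compiled on its own before being frozen, calling discriminator.train_on_batch still updates its weights; the trainable = False flag only affects models compiled afterwards, i.e. the combined model.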


Here the discriminator is trained with train_on_batch plus a sample_weight argument; this sample_weight corresponds to the targets [y, aux_y], one weight array per output, as these debug prints illustrate:

print(len(disc_sample_weight))
print(len(disc_sample_weight[0]))
print(len(disc_sample_weight[1]))
tmp = [y, aux_y]
print(len(tmp))
print(len(tmp[0]))
print(len(tmp[1]))

That is roughly the idea. The loss on y, i.e. whether the image is real, is computed normally; the one small twist is that real images get a y label of 0.95 instead of 1 (one-sided label smoothing).

For the aux_y loss, classifying a freshly generated image is not very meaningful, so its loss is weighted by 0; for images from the MNIST dataset the classification loss is weighted by 2 to compensate.

With these weights in place, we train the discriminator once per batch, roughly as sketched below.
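A sketch of one discriminator update that follows this recipe (image_batch, label_batch and batch_size are assumed to come from the surrounding training loop):

import numpy as np

soft_zero, soft_one = 0.0, 0.95      # smoothed "real" label

noise = np.random.uniform(-1, 1, (batch_size, latent_size))
sampled_labels = np.random.randint(0, 10, batch_size)
generated_images = generator.predict([noise, sampled_labels.reshape((-1, 1))])

x = np.concatenate((image_batch, generated_images))               # real + fake
y = np.array([soft_one] * batch_size + [soft_zero] * batch_size)  # real=0.95, fake=0
aux_y = np.concatenate((label_batch, sampled_labels), axis=0)

# auxiliary (classification) loss weights: 2x for real images, 0x for fakes
disc_sample_weight = [np.ones(2 * batch_size),
                      np.concatenate((np.ones(batch_size) * 2,
                                      np.zeros(batch_size)))]

disc_loss = discriminator.train_on_batch(
    x, [y, aux_y], sample_weight=disc_sample_weight)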

Then we generate another batch of images, set all of their "is real" labels to 0.95, and train the combined network once. Inside that network discriminator.trainable = False, so this step only updates the generator, roughly as sketched below.
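And a sketch of the matching generator update through the combined model; the trick labels claim the fakes are real (0.95, reusing soft_one from the sketch above), so the gradient pushes the generator towards fooling the frozen discriminator:

noise = np.random.uniform(-1, 1, (2 * batch_size, latent_size))
sampled_labels = np.random.randint(0, 10, 2 * batch_size)
trick = np.ones(2 * batch_size) * soft_one       # pretend the fakes are real

gen_loss = combined.train_on_batch(
    [noise, sampled_labels.reshape((-1, 1))],
    [trick, sampled_labels])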

That is basically the whole training loop; the rest of the code computes the test losses and saves sample generated images.
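For completeness, one way the sample images could be saved at the end of an epoch (a sketch using Pillow; the file name and the 10x10 grid layout are assumptions):

import numpy as np
from PIL import Image

noise = np.random.uniform(-1, 1, (100, latent_size))
sampled_labels = np.array([[i] for i in range(10) for _ in range(10)])   # ten of each digit
generated = generator.predict([noise, sampled_labels])                   # (100, 28, 28, 1)

# arrange into a 10x10 grid and rescale tanh output from [-1, 1] to [0, 255]
grid = (generated.reshape(10, 10, 28, 28)
                 .transpose(0, 2, 1, 3)
                 .reshape(10 * 28, 10 * 28))
grid = ((grid + 1) * 127.5).astype(np.uint8)
Image.fromarray(grid).save('generated_digits.png')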

The generated digit samples look quite good.


——————————————————————

Full index

Analyses of the Keras example files
