生成对抗网络(GAN)
生活随笔
收集整理的這篇文章主要介紹了
生成对抗网络(GAN)
小編覺得挺不錯的,現(xiàn)在分享給大家,幫大家做個參考.
學習目標
- 目標
- 了解GAN的作用
- 說明GAN的訓練過程
- 知道DCGAN的結(jié)構(gòu)
- 應用
- 應用DCGAN模型實現(xiàn)手寫數(shù)字的生成
5.1.1 GAN能做什么
GAN是非監(jiān)督式學習的一種方法,在2014年被提出。
GAN主要用途:
- 生成以假亂真的圖片
?
- 生成視頻、模型
5.1.2 什么GAN
5.1.2.1 定義
生成對抗網(wǎng)絡(luò)(Generative Adversarial Network,簡稱GAN),主要結(jié)構(gòu)包括一個生成器G(Generator)和一個判別器D(Discriminator)。
?
- 生成器(Generator),能夠輸入一個向量,輸出需要生成固定大小的像素圖像
- 判別器(Discriminator),用來判別圖片是真的還是假的,輸入圖片(訓練的數(shù)據(jù)或者生成的數(shù)據(jù)),輸出為判別圖片的標簽
5.1.2.2 理解
- 思想:從訓練庫里獲取很多訓練樣本,從而學習這些訓練案例生成的概率分布
?
- 黑色虛線:真是樣本的分布
- 綠色實線:生成樣本的分布
- 藍色虛線:判別器判斷的概率分布
- zz表示噪聲,zz到xx表示生成器生成的分布映射
過程分析:
- 1、定義GAN結(jié)構(gòu)生成數(shù)據(jù)
- (a)(a)狀態(tài)處于最初始的狀態(tài),生成器生成的分布和真實分布區(qū)別較大,并且判別器判別出樣本的概率不穩(wěn)定
- 2、在真實數(shù)據(jù)上訓練 n epochs判別器,產(chǎn)生fake(假數(shù)據(jù))并訓練判別器識別為假
- 通過多次訓練判別器來達到(b)(b)樣本狀態(tài),此時判別樣本區(qū)分得非常顯著
- 3、訓練生成器達到欺騙判別器的效果
- 訓練生成器之后達到(c)(c)樣本狀態(tài),此時生成器分布相比之前,逼近了真實樣本分布。經(jīng)過多次反復訓練迭代之后。
- 最終希望能夠達到(d)(d)狀態(tài),生成樣本分布擬合于真實樣本分布,并且判別器分辨不出樣本是生成的還是真實的。
5.1.2.3 訓練損失
?
- V(G, D)V(G,D):表示P_ x和 P_z 的差異程度。
- \max \limits_DV(D, G)?D?max??V(D,G)?:固定生成器G, 盡可能地讓判別器能夠最大化地判別出樣本來自于真實數(shù)據(jù)還是生成的數(shù)據(jù)
- \min \limits_G L?G?min??L:固定判別器D的條件下得到生成器G,能夠最小化真實樣本與生成樣本的差異。
整個優(yōu)化我們其實只看做一個部分:
- 判別器:相當于一個分類器,判斷圖片的真?zhèn)?#xff0c;二分類問題,使用交叉熵損失
對于真實樣本:對數(shù)預測概率損失,提高預測的概率
?
對于生成樣本:對數(shù)預測概率損失,降低預測概率
?
最終可以這樣:
?
5.1.2.4 G、D結(jié)構(gòu)
G、D結(jié)構(gòu)是兩個網(wǎng)絡(luò),特點是能夠反向傳播可導計算要介紹G、D結(jié)構(gòu),需要區(qū)分不同版本的GAN。
- 2014年最開始的模型:
- G、D都是multilayer perceptron(MLP)
- 缺點:實踐證明訓練難度大,效果不行
- 2015:使用卷積神經(jīng)網(wǎng)絡(luò)+GAN(DCGAN(Deep Convolutional GAN))
- 改進:
- 1、判別器D中取出pooling,全部變成卷積、生成器G中使用反卷積(下圖)
- 2、D、G中都增加了BN層
- 3、去除了所有的全連接層
- 4、判別器D中全部使用Leaky ReLU,生成器除了最后輸出層使用tanh其它層全換成ReLU
- 改進:
?
5.1.3 案例:GAN生成手寫數(shù)字圖像
5.1.3.1 案例演示與結(jié)果顯示
- 迭代不同次數(shù)生成的圖片效果
- 1次
?
- 50
?
- 2000次
?
5.1.3.2 代碼步驟流程
- 初始化GAN模型結(jié)構(gòu)
- init_model(self)
- 判別器:CNN,build_discriminator
- 生成器:CNN,build_generator
- 訓練過程:train(self, epochs, batch_size=128, save_interval=50)
- 訓練判別器
- 訓練生成器
- 生成圖片保存
5.1.3.3 代碼編寫
- 1、模型類
class DCGAN():def __init__(self):# 輸入圖片的形狀self.img_rows = 28self.img_cols = 28self.channels = 1self.img_shape = (self.img_rows, self.img_cols, self.channels)
- 2、初始化GAN模型結(jié)構(gòu)
- 建立D判別器CNN結(jié)構(gòu),初始化判別器訓練優(yōu)化參數(shù)
- 聯(lián)合建立G生成器CNN結(jié)構(gòu),初始化生成器訓練優(yōu)化參數(shù)
- 輸入噪點數(shù)據(jù),輸出預測的類別概率
- 注意生成器訓練時,判別器不進行訓練
- from keras.optimizers import Adam
def init_model(self):# 生成原始噪點數(shù)據(jù)大小self.latent_dim = 100optimizer = Adam(0.0002, 0.5)# 1、建立判別器訓練參數(shù)# 選擇損失,優(yōu)化器,以及衡量準確率self.discriminator = self.build_discriminator()self.discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])# 2、聯(lián)合建立生成器訓練參數(shù),指定生成器損失self.generator = self.build_generator()z = Input(shape=(self.latent_dim,))img = self.generator(z)# 合并模型的損失,并且之后只訓練生成器,判別器不訓練self.discriminator.trainable = Falsevalid = self.discriminator(img)# 訓練生成器欺騙判別器self.combined = Model(z, valid)self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
- 定義模型的判別器
- from keras.layers import Input, Dense, Reshape, Flatten, Dropout from keras.layers import BatchNormalization, Activation, ZeroPadding2D from keras.layers.advanced_activations import LeakyReLU from keras.layers.convolutional import UpSampling2D, Conv2D from keras.models import Sequential, Model
def build_discriminator(self):model = Sequential()model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))model.add(ZeroPadding2D(padding=((0,1),(0,1))))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Flatten())model.add(Dense(1, activation='sigmoid'))model.summary()img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)
- 定義模型的生成器
- CNN結(jié)構(gòu)
def build_generator(self):model = Sequential()model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))model.add(Reshape((7, 7, 128)))model.add(UpSampling2D())model.add(Conv2D(128, kernel_size=3, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(Activation("relu"))model.add(UpSampling2D())model.add(Conv2D(64, kernel_size=3, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(Activation("relu"))model.add(Conv2D(self.channels, kernel_size=3, padding="same"))model.add(Activation("tanh"))model.summary()noise = Input(shape=(self.latent_dim,))img = model(noise)return Model(noise, img)
- 3、訓練模型代碼
- from keras.datasets import mnist
- import matplotlib.pyplot as plt
- import numpy as np
- model:train_on_batch(feature, target)
def train(self, epochs, batch_size=32):# 加載手寫數(shù)字(X_train, _), (_, _) = mnist.load_data()# 進行歸一化X_train = X_train / 127.5 - 1.X_train = np.expand_dims(X_train, axis=3)# 正負樣本的目標值建立valid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):# 1、訓練判別器# 選擇隨機的一些真實樣本idx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]# 生成器產(chǎn)生假樣本noise = np.random.normal(0, 1, (batch_size, self.latent_dim))gen_imgs = self.generator.predict(noise)# 訓練判別器過程d_loss_real = self.discriminator.train_on_batch(imgs, valid)d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)# 計算平均兩部分損失d_loss = np.add(d_loss_real, d_loss_fake) / 2# 2、訓練生成器,停止判別器# 合并訓練,并停止訓練判別器# 用目標值為1去訓練,目的使得生成器生成的樣本越來越接近真是樣本g_loss = self.combined.train_on_batch(noise, valid)# 畫出結(jié)果print("迭代次數(shù):%d [D 損失: %f, 準確率: %.2f%], [G 損失: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))# 保存生成的圖片if epoch % 50 == 0:self.save_imgs(epoch)
- 保存生成的圖片
def save_imgs(self, epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, self.latent_dim))gen_imgs = self.generator.predict(noise)# Rescale images 0 - 1gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')axs[i, j].axis('off')cnt += 1fig.savefig("./images/mnist_%d.png" % epoch)plt.close()
5.1.4 總結(jié)
- 掌握GAN模型的原理過程
- 掌握GAN手寫數(shù)字的訓練過程
總結(jié)
以上是生活随笔為你收集整理的生成对抗网络(GAN)的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: seq2seq与Attention机制
- 下一篇: 自动编码器