GAN 对抗生成网络代码实现
生活随笔
收集整理的這篇文章主要介紹了
GAN 对抗生成网络代码实现
小編覺得挺不錯的,現(xiàn)在分享給大家,幫大家做個參考.
作報告寫了ppt,這里po上?
更完整的介紹關(guān)注專欄生成對抗網(wǎng)絡(luò)Generative Adversarial Network
本篇的同名博客[生成對抗網(wǎng)絡(luò)GAN入門指南](3)GAN的工程實踐及基礎(chǔ)代碼
In?[1]:
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import numpy as np import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec import osIn?[2]:
#該函數(shù)將給出權(quán)重初始化的方法 def variable_init(size):in_dim = size[0]#計算隨機(jī)生成變量所服從的正態(tài)分布標(biāo)準(zhǔn)差w_stddev = 1. / tf.sqrt(in_dim / 2.)return tf.random_normal(shape=size, stddev=w_stddev)In?[3]:
#定義輸入矩陣的占位符,輸入層單元為784,None代表批量大小的占位,X代表輸入的真實圖片。占位符的數(shù)值類型為32位浮點(diǎn)型 X = tf.placeholder(tf.float32, shape=[None, 784])#定義判別器的權(quán)重矩陣和偏置項向量,由此可知判別網(wǎng)絡(luò)為三層全連接網(wǎng)絡(luò) D_W1 = tf.Variable(variable_init([784, 128])) D_b1 = tf.Variable(tf.zeros(shape=[128]))D_W2 = tf.Variable(variable_init([128, 1])) D_b2 = tf.Variable(tf.zeros(shape=[1]))theta_D = [D_W1, D_W2, D_b1, D_b2]#定義生成器的輸入噪聲為100維度的向量組,None根據(jù)批量大小確定 Z = tf.placeholder(tf.float32, shape=[None, 100])#定義生成器的權(quán)重與偏置項。輸入層為100個神經(jīng)元且接受隨機(jī)噪聲, #輸出層為784個神經(jīng)元,并輸出手寫字體圖片。生成網(wǎng)絡(luò)根據(jù)原論文為三層全連接網(wǎng)絡(luò) G_W1 = tf.Variable(variable_init([100, 128])) G_b1 = tf.Variable(tf.zeros(shape=[128]))G_W2 = tf.Variable(variable_init([128, 784])) G_b2 = tf.Variable(tf.zeros(shape=[784]))theta_G = [G_W1, G_W2, G_b1, G_b2]In?[4]:
#定義一個可以生成m*n階隨機(jī)矩陣的函數(shù),該矩陣的元素服從均勻分布,隨機(jī)生成的z就為生成器的輸入 def sample_Z(m, n):return np.random.uniform(-1., 1., size=[m, n])In?[5]:
#定義生成器 def generator(z):#第一層先計算 y=z*G_W1+G-b1,然后投入激活函數(shù)計算G_h1=ReLU(y),G_h1 為第二次層神經(jīng)網(wǎng)絡(luò)的輸出激活值G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)#以下兩個語句計算第二層傳播到第三層的激活結(jié)果,第三層的激活結(jié)果是含有784個元素的向量,該向量轉(zhuǎn)化28×28就可以表示圖像G_log_prob = tf.matmul(G_h1, G_W2) + G_b2G_prob = tf.nn.sigmoid(G_log_prob)return G_probIn?[6]:
#定義判別器 def discriminator(x):#計算D_h1=ReLU(x*D_W1+D_b1),該層的輸入為含784個元素的向量D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)#計算第三層的輸出結(jié)果。因為使用的是Sigmoid函數(shù),則該輸出結(jié)果是一個取值為[0,1]間的標(biāo)量(見上述權(quán)重定義)#即判別輸入的圖像到底是真(=1)還是假(=0)D_logit = tf.matmul(D_h1, D_W2) + D_b2D_prob = tf.nn.sigmoid(D_logit)#返回判別為真的概率和第三層的輸入值,輸出D_logit是為了將其輸入tf.nn.sigmoid_cross_entropy_with_logits()以構(gòu)建損失函數(shù)return D_prob, D_logitIn?[7]:
#該函數(shù)用于輸出生成圖片 def plot(samples):fig = plt.figure(figsize=(4, 4))gs = gridspec.GridSpec(4, 4)gs.update(wspace=0.05, hspace=0.05)for i, sample in enumerate(samples):ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels([])ax.set_aspect('equal')plt.imshow(sample.reshape(28, 28), cmap='Greys_r')return fig交叉熵?fù)p失函數(shù)
函數(shù)的輸入是和,就是神經(jīng)網(wǎng)絡(luò)模型中的矩陣,且不需要經(jīng)過激活函數(shù)。而的shape和相同,即正確的標(biāo)注值。
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??
那么該函數(shù)的表達(dá)式為
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??
In?[8]:
#輸入隨機(jī)噪聲z而輸出生成樣本 G_sample = generator(Z)#分別輸入真實圖片和生成的圖片,并投入判別器以判斷真?zhèn)?D_real, D_logit_real = discriminator(X) D_fake, D_logit_fake = discriminator(G_sample)#以下為原論文的判別器損失和生成器損失,但本實現(xiàn)并沒有使用該損失函數(shù) # D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake)) # G_loss = -tf.reduce_mean(tf.log(D_fake))# 我們使用交叉熵作為判別器和生成器的損失函數(shù),因為sigmoid_cross_entropy_with_logits內(nèi)部會對預(yù)測輸入執(zhí)行Sigmoid函數(shù), #所以我們?nèi)∨袆e器最后一層未投入激活函數(shù)的值,即D_h1*D_W2+D_b2。 #tf.ones_like(D_logit_real)創(chuàng)建維度和D_logit_real相等的全是1的標(biāo)注,真實圖片。 D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real))) D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))#損失函數(shù)為兩部分,即E[log(D(x))]+E[log(1-D(G(z)))],將真的判別為假和將假的判別為真 D_loss = D_loss_real + D_loss_fake#同樣使用交叉熵構(gòu)建生成器損失函數(shù) G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))#定義判別器和生成器的優(yōu)化方法為Adam算法,關(guān)鍵字var_list表明最小化損失函數(shù)所更新的權(quán)重矩陣 D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D) G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)In?[9]:
#選擇訓(xùn)練的批量大小和隨機(jī)生成噪聲的維度 mb_size = 128 Z_dim = 100#讀取數(shù)據(jù)集MNIST,并放在當(dāng)前目錄data文件夾下MNIST文件夾中,如果該地址沒有數(shù)據(jù),則下載數(shù)據(jù)至該文件夾 mnist = input_data.read_data_sets("./data/MNIST/", one_hot=True) Extracting ./data/MNIST/train-images-idx3-ubyte.gz Extracting ./data/MNIST/train-labels-idx1-ubyte.gz Extracting ./data/MNIST/t10k-images-idx3-ubyte.gz Extracting ./data/MNIST/t10k-labels-idx1-ubyte.gzIn?[10]:
#打開一個會話運(yùn)行計算圖 sess = tf.Session()#初始化所有定義的變量 sess.run(tf.global_variables_initializer())#如果當(dāng)前目錄下不存在out文件夾,則創(chuàng)建該文件夾 if not os.path.exists('out/'):os.makedirs('out/')#初始化,并開始迭代訓(xùn)練,100W次 i = 0 for it in range(20000):#每2000次輸出一張生成器生成的圖片if it % 2000 == 0:samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})fig = plot(samples)plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')i += 1plt.close(fig)#next_batch抽取下一個批量的圖片,該方法返回一個矩陣,即shape=[mb_size,784],每一行是一張圖片,共批量大小行X_mb, _ = mnist.train.next_batch(mb_size)#投入數(shù)據(jù)并根據(jù)優(yōu)化方法迭代一次,計算損失后返回?fù)p失值_, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})_, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})#每迭代2000次輸出迭代數(shù)、生成器損失和判別器損失if it % 2000 == 0:print('Iter: {}'.format(it))print('D loss: {:.4}'. format(D_loss_curr))print('G_loss: {:.4}'.format(G_loss_curr))print() Iter: 0 D loss: 1.671 G_loss: 1.718Iter: 2000 D loss: 0.05008 G_loss: 4.74Iter: 4000 D loss: 0.3667 G_loss: 4.85Iter: 6000 D loss: 0.3974 G_loss: 4.059Iter: 8000 D loss: 0.7007 G_loss: 2.628Iter: 10000 D loss: 0.4421 G_loss: 3.05Iter: 12000 D loss: 0.7872 G_loss: 2.562Iter: 14000 D loss: 0.7155 G_loss: 2.877Iter: 16000 D loss: 0.9827 G_loss: 2.042Iter: 18000 D loss: 0.7171 G_loss: 1.966?
總結(jié)
以上是生活随笔為你收集整理的GAN 对抗生成网络代码实现的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: JS原型概念讲解
- 下一篇: VMware 12 安装 OS X 10