深度学习(11)-- GAN
TensorFlow (GAN)
目錄
- TensorFlow (GAN)
- 目錄
- 1、GAN
- 1.1 常見神經網絡形式
- 1.2 生成網絡
- 1.3 新手畫家 & 新手鑒賞家
- 1.4 GAN網絡
- 1.5 例子
- 1、GAN
1、GAN
今天我們會來說說現在最流行的一種生成網絡, 叫做 GAN, 又稱生成對抗網絡, 也是 Generative Adversarial Nets 的簡稱
1.1 常見神經網絡形式
神經網絡分很多種, 有普通的前向傳播神經網絡 , 有分析圖片的 CNN 卷積神經網絡 , 有分析序列化數據, 比如語音的 RNN 循環神經網絡 , 這些神經網絡都是用來輸入數據, 得到想要的結果, 我們看中的是這些神經網絡能很好的將數據與結果通過某種關系聯系起來
1.2 生成網絡
但是還有另外一種形式的神經網絡, 他不是用來把數據對應上結果的, 而是用來”憑空”捏造結果, 這就是我們要說的生成網絡啦. GAN 就是其中的一種形式. 那么 GAN 是怎么做到的呢? 當然這里的”憑空”并不是什么都沒有的空盒子, 而是一些隨機數.
對, 你沒聽錯, 我們就是用沒有意義的隨機數來生成有有意義的作品, 比如著名畫作. 當然, 這還不是全部, 這只是一個 GAN 的一部分而已, 這一部分的神經網絡我們可以想象成是一個新手畫家.
1.3 新手畫家 & 新手鑒賞家
畫家作畫都需要點靈感 , 他們都是依照自己的靈感來完成作品. 有了靈感不一定有用, 因為他的作畫技術并沒有我們想象得好, 畫出來有可能是一團糟. 這可怎么辦, 聰明的新手畫家找到了自己的一個正在學鑒賞的好朋友 – 新手鑒賞家.
可是新手鑒賞家也沒什么能耐, 他也不知道如何鑒賞著名畫作 , 所以坐在電腦旁邊的你實在看不下去了, 拿起幾個標簽往屏幕上一甩 , 然后新手鑒賞家就被你這樣一次次的甩來甩去著甩乖了, 慢慢也學會了怎么樣區分著名畫家的畫了. 重要的是, 新手鑒賞家和新手畫家是好朋友, 他們總愛分享學習到的東西.
所以新手鑒賞家告訴新手畫家, “你的畫實在太丑了, 你看看人家達芬奇, 你也學學它呀, 比如這里要多加一點, 這里要畫淡一點.” 就這樣, 新手鑒賞家將他從你這里所學到的知識都分享給了新手畫家, 讓好朋友新手畫家也能越畫越像達芬奇. 這就是 GAN 的整套流程, 我們在來理一下.
新手畫家用隨機靈感畫畫 , 新手鑒賞家會接收一些畫作, 但是他不知道這是新手畫家畫的還是著名畫家畫的, 他說出他的判斷, 你來糾正他的判斷, 新手鑒賞家一邊學如何判斷, 一邊告訴新手畫家要怎么畫才能畫得更像著名畫家, 新手畫家就能學習到如何從自己的靈感畫出更像著名畫家的畫了. GAN 也就這么回事.
1.4 GAN網絡
Generator 會根據隨機數來生成有意義的數據 , Discriminator 會學習如何判斷哪些是真實數據 , 哪些是生成數據, 然后將學習的經驗反向傳遞給 Generator, 讓 Generator 能根據隨機數生成更像真實數據的數據. 這樣訓練出來的 Generator 可以有很多用途, 比如最近有人就拿它來生成各種臥室的圖片.
甚至你還能玩點新花樣, 比如讓圖片來做加減法, 戴眼鏡的男人 減去 男人 加上 女人, 他居然能生成 戴眼鏡的女人的圖片. 甚至還能根據你隨便畫的幾筆草圖來生成可能是你需要的藍天白云大草地圖片. 哈哈, 看起來機器也能有想象力啦
1.5 例子
import tensorflow as tf import numpy as np import matplotlib.pyplot as plttf.set_random_seed(1) np.random.seed(1)# Hyper Parameters BATCH_SIZE = 64 LR_G = 0.0001 # learning rate for generator LR_D = 0.0001 # learning rate for discriminator N_IDEAS = 5 # think of this as number of ideas for generating an art work (Generator) ART_COMPONENTS = 15 # it could be total point G can draw in the canvas PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])# show our beautiful painting range plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound') plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound') plt.legend(loc='upper right') plt.show()def artist_works(): # painting from the famous artist (real target)a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]paintings = a * np.power(PAINT_POINTS, 2) + (a-1)return paintingswith tf.variable_scope('Generator'):G_in = tf.placeholder(tf.float32, [None, N_IDEAS]) # random ideas (could from normal distribution)G_l1 = tf.layers.dense(G_in, 128, tf.nn.relu)G_out = tf.layers.dense(G_l1, ART_COMPONENTS) # making a painting from these random ideaswith tf.variable_scope('Discriminator'):real_art = tf.placeholder(tf.float32, [None, ART_COMPONENTS], name='real_in') # receive art work from the famous artistD_l0 = tf.layers.dense(real_art, 128, tf.nn.relu, name='l')prob_artist0 = tf.layers.dense(D_l0, 1, tf.nn.sigmoid, name='out') # probability that the art work is made by artist# reuse layers for generatorD_l1 = tf.layers.dense(G_out, 128, tf.nn.relu, name='l', reuse=True) # receive art work from a newbie like Gprob_artist1 = tf.layers.dense(D_l1, 1, tf.nn.sigmoid, name='out', reuse=True) # probability that the art work is made by artistD_loss = -tf.reduce_mean(tf.log(prob_artist0) + tf.log(1-prob_artist1)) G_loss = tf.reduce_mean(tf.log(1-prob_artist1))train_D = tf.train.AdamOptimizer(LR_D).minimize(D_loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Discriminator')) train_G = tf.train.AdamOptimizer(LR_G).minimize(G_loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Generator'))sess = tf.Session() sess.run(tf.global_variables_initializer())plt.ion() # something about continuous plotting for step in range(5000):artist_paintings = artist_works() # real painting from artistG_ideas = np.random.randn(BATCH_SIZE, N_IDEAS)G_paintings, pa0, Dl = sess.run([G_out, prob_artist0, D_loss, train_D, train_G], # train and get results{G_in: G_ideas, real_art: artist_paintings})[:3]if step % 50 == 0: # plottingplt.cla()plt.plot(PAINT_POINTS[0], G_paintings[0], c='#4AD631', lw=3, label='Generated painting',)plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % pa0.mean(), fontdict={'size': 15})plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -Dl, fontdict={'size': 15})plt.ylim((0, 3)); plt.legend(loc='upper right', fontsize=12); plt.draw(); plt.pause(0.01)plt.ioff() plt.show()
總結
以上是生活随笔為你收集整理的深度学习(11)-- GAN的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python(16)-列表list,fo
- 下一篇: 深度学习(07)-- 经典CNN网络结构