一文读懂生成对抗网络(GANs)
GAN網(wǎng)絡(luò)是近兩年深度學(xué)習(xí)領(lǐng)域的新秀,火的不行,本文旨在淺顯理解傳統(tǒng)GAN,分享學(xué)習(xí)心得。現(xiàn)有GAN網(wǎng)絡(luò)大多數(shù)代碼實(shí)現(xiàn)使用Python、torch等語(yǔ)言,這里,后面用matlab搭建一個(gè)簡(jiǎn)單的GAN網(wǎng)絡(luò),便于理解GAN原理。
GAN的鼻祖之作是2014年NIPS一篇文章:Generative Adversarial Net(https://arxiv.org/abs/1406.2661),可以細(xì)細(xì)品味。
▌開(kāi)始
我們知道GAN的思想是是一種二人零和博弈思想(two-player game),博弈雙方的利益之和是一個(gè)常數(shù),比如兩個(gè)人掰手腕,假設(shè)總的空間是一定的,你的力氣大一點(diǎn),那你就得到的空間多一點(diǎn),相應(yīng)的我的空間就少一點(diǎn),相反我力氣大我就得到的多一點(diǎn),但有一點(diǎn)是確定的就是,我兩的總空間是一定的,這就是二人博弈,但是呢總利益是一定的。
引申到GAN里面就是可以看成,GAN中有兩個(gè)這樣的博弈者,一個(gè)人名字是生成模型(G),另一個(gè)人名字是判別模型(D)。他們各自有各自的功能。
相同點(diǎn)是:
這兩個(gè)模型都可以看成是一個(gè)黑匣子,接受輸入然后有一個(gè)輸出,類(lèi)似一個(gè)函數(shù),一個(gè)輸入輸出映射。
不同點(diǎn)是:
生成模型功能:比作是一個(gè)樣本生成器,輸入一個(gè)噪聲/樣本,然后把它包裝成一個(gè)逼真的樣本,也就是輸出。
判別模型:比作一個(gè)二分類(lèi)器(如同0-1分類(lèi)器),來(lái)判斷輸入的樣本是真是假。(就是輸出值大于0.5還是小于0.5)
直接上一張個(gè)人覺(jué)得解釋的好的圖說(shuō)明:
在之前,我們首先明白在使用GAN的時(shí)候的2個(gè)問(wèn)題
我們有什么?
比如上面的這個(gè)圖,我們有的只是真實(shí)采集而來(lái)的人臉樣本數(shù)據(jù)集,僅此而已,而且很關(guān)鍵的一點(diǎn)是我們連人臉數(shù)據(jù)集的類(lèi)標(biāo)簽都沒(méi)有,也就是我們不知道那個(gè)人臉對(duì)應(yīng)的是誰(shuí)。
我們要得到什么?
至于要得到什么,不同的任務(wù)得到的東西不一樣,我們只說(shuō)最原始的GAN目的,那就是我們想通過(guò)輸入一個(gè)噪聲,模擬得到一個(gè)人臉圖像,這個(gè)圖像可以非常逼真以至于以假亂真。
好了再來(lái)理解下GAN的兩個(gè)模型要做什么。
首先判別模型,就是圖中右半部分的網(wǎng)絡(luò),直觀來(lái)看就是一個(gè)簡(jiǎn)單的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),輸入就是一副圖像,輸出就是一個(gè)概率值,用于判斷真假使用(概率值大于0.5那就是真,小于0.5那就是假),真假也不過(guò)是人們定義的概率而已。
其次是生成模型,生成模型要做什么呢,同樣也可以看成是一個(gè)神經(jīng)網(wǎng)絡(luò)模型,輸入是一組隨機(jī)數(shù)Z,輸出是一個(gè)圖像,不再是一個(gè)數(shù)值而已。從圖中可以看到,會(huì)存在兩個(gè)數(shù)據(jù)集,一個(gè)是真實(shí)數(shù)據(jù)集,這好說(shuō),另一個(gè)是假的數(shù)據(jù)集,那這個(gè)數(shù)據(jù)集就是有生成網(wǎng)絡(luò)造出來(lái)的數(shù)據(jù)集。好了根據(jù)這個(gè)圖我們?cè)賮?lái)理解一下GAN的目標(biāo)是要干什么:
判別網(wǎng)絡(luò)的目的:就是能判別出來(lái)屬于的一張圖它是來(lái)自真實(shí)樣本集還是假樣本集。假如輸入的是真樣本,網(wǎng)絡(luò)輸出就接近1,輸入的是假樣本,網(wǎng)絡(luò)輸出接近0,那么很完美,達(dá)到了很好判別的目的。
生成網(wǎng)絡(luò)的目的:生成網(wǎng)絡(luò)是造樣本的,它的目的就是使得自己造樣本的能力盡可能強(qiáng),強(qiáng)到什么程度呢,你判別網(wǎng)絡(luò)沒(méi)法判斷我是真樣本還是假樣本。
有了這個(gè)理解我們?cè)賮?lái)看看為什么叫做對(duì)抗網(wǎng)絡(luò)了。判別網(wǎng)絡(luò)說(shuō),我很強(qiáng),來(lái)一個(gè)樣本我就知道它是來(lái)自真樣本集還是假樣本集。生成網(wǎng)絡(luò)就不服了,說(shuō)我也很強(qiáng),我生成一個(gè)假樣本,雖然我生成網(wǎng)絡(luò)知道是假的,但是你判別網(wǎng)絡(luò)不知道呀,我包裝的非常逼真,以至于判別網(wǎng)絡(luò)無(wú)法判斷真假,那么用輸出數(shù)值來(lái)解釋就是,生成網(wǎng)絡(luò)生成的假樣本進(jìn)去了判別網(wǎng)絡(luò)以后,判別網(wǎng)絡(luò)給出的結(jié)果是一個(gè)接近0.5的值,極限情況就是0.5,也就是說(shuō)判別不出來(lái)了,這就是納什平衡了。
由這個(gè)分析可以發(fā)現(xiàn),生成網(wǎng)絡(luò)與判別網(wǎng)絡(luò)的目的正好是相反的,一個(gè)說(shuō)我能判別的好,一個(gè)說(shuō)我讓你判別不好。所以叫做對(duì)抗,叫做博弈。那么最后的結(jié)果到底是誰(shuí)贏呢?這就要?dú)w結(jié)到設(shè)計(jì)者,也就是我們希望誰(shuí)贏了。作為設(shè)計(jì)者的我們,我們的目的是要得到以假亂真的樣本,那么很自然的我們希望生成樣本贏了,也就是希望生成樣本很真,判別網(wǎng)絡(luò)能力不足以區(qū)分真假樣本位置。
▌再理解
知道了GAN大概的目的與設(shè)計(jì)思路,那么一個(gè)很自然的問(wèn)題來(lái)了就是我們?cè)撊绾斡脭?shù)學(xué)方法解決這么一個(gè)對(duì)抗問(wèn)題。這就涉及到如何訓(xùn)練這樣一個(gè)生成對(duì)抗網(wǎng)絡(luò)模型了,還是先上一個(gè)圖,用圖來(lái)解釋最直接:
需要注意的是生成模型與對(duì)抗模型可以說(shuō)是完全獨(dú)立的兩個(gè)模型,好比就是完全獨(dú)立的兩個(gè)神經(jīng)網(wǎng)絡(luò)模型,他們之間沒(méi)有什么聯(lián)系。
好了那么訓(xùn)練這樣的兩個(gè)模型的大方法就是:單獨(dú)交替迭代訓(xùn)練。
什么意思?因?yàn)槭?個(gè)網(wǎng)絡(luò),不好一起訓(xùn)練,所以才去交替迭代訓(xùn)練,我們一一來(lái)看。
假設(shè)現(xiàn)在生成網(wǎng)絡(luò)模型已經(jīng)有了(當(dāng)然可能不是最好的生成網(wǎng)絡(luò)),那么給一堆隨機(jī)數(shù)組,就會(huì)得到一堆假的樣本集(因?yàn)椴皇亲罱K的生成模型,那么現(xiàn)在生成網(wǎng)絡(luò)可能就處于劣勢(shì),導(dǎo)致生成的樣本就不咋地,可能很容易就被判別網(wǎng)絡(luò)判別出來(lái)了說(shuō)這貨是假冒的),但是先不管這個(gè),假設(shè)我們現(xiàn)在有了這樣的假樣本集,真樣本集一直都有,現(xiàn)在我們?nèi)藶榈亩x真假樣本集的標(biāo)簽,因?yàn)槲覀兿M鏄颖炯妮敵霰M可能為1,假樣本集為0,很明顯這里我們就已經(jīng)默認(rèn)真樣本集所有的類(lèi)標(biāo)簽都為1,而假樣本集的所有類(lèi)標(biāo)簽都為0.
有人會(huì)說(shuō),在真樣本集里面的人臉中,可能張三人臉和李四人臉不一樣呀,對(duì)于這個(gè)問(wèn)題我們需要理解的是,我們現(xiàn)在的任務(wù)是什么,我們是想分樣本真假,而不是分真樣本中那個(gè)是張三label、那個(gè)是李四label。況且我們也知道,原始真樣本的label我們是不知道的。回過(guò)頭來(lái),我們現(xiàn)在有了真樣本集以及它們的label(都是1)、假樣本集以及它們的label(都是0),這樣單就判別網(wǎng)絡(luò)來(lái)說(shuō),此時(shí)問(wèn)題就變成了一個(gè)再簡(jiǎn)單不過(guò)的有監(jiān)督的二分類(lèi)問(wèn)題了,直接送到神經(jīng)網(wǎng)絡(luò)模型中訓(xùn)練就完事了。假設(shè)訓(xùn)練完了,下面我們來(lái)看生成網(wǎng)絡(luò)。
對(duì)于生成網(wǎng)絡(luò),想想我們的目的,是生成盡可能逼真的樣本。那么原始的生成網(wǎng)絡(luò)生成的樣本你怎么知道它真不真呢?就是送到判別網(wǎng)絡(luò)中,所以在訓(xùn)練生成網(wǎng)絡(luò)的時(shí)候,我們需要聯(lián)合判別網(wǎng)絡(luò)一起才能達(dá)到訓(xùn)練的目的。什么意思?就是如果我們單單只用生成網(wǎng)絡(luò),那么想想我們?cè)趺慈ビ?xùn)練?誤差來(lái)源在哪里?細(xì)想一下沒(méi)有,但是如果我們把剛才的判別網(wǎng)絡(luò)串接在生成網(wǎng)絡(luò)的后面,這樣我們就知道真假了,也就有了誤差了。所以對(duì)于生成網(wǎng)絡(luò)的訓(xùn)練其實(shí)是對(duì)生成-判別網(wǎng)絡(luò)串接的訓(xùn)練,就像圖中顯示的那樣。好了那么現(xiàn)在來(lái)分析一下樣本,原始的噪聲數(shù)組Z我們有,也就是生成了假樣本我們有,此時(shí)很關(guān)鍵的一點(diǎn)來(lái)了,我們要把這些假樣本的標(biāo)簽都設(shè)置為1,也就是認(rèn)為這些假樣本在生成網(wǎng)絡(luò)訓(xùn)練的時(shí)候是真樣本。
那么為什么要這樣呢?我們想想,是不是這樣才能起到迷惑判別器的目的,也才能使得生成的假樣本逐漸逼近為正樣本。好了,重新順一下思路,現(xiàn)在對(duì)于生成網(wǎng)絡(luò)的訓(xùn)練,我們有了樣本集(只有假樣本集,沒(méi)有真樣本集),有了對(duì)應(yīng)的label(全為1),是不是就可以訓(xùn)練了?有人會(huì)問(wèn),這樣只有一類(lèi)樣本,訓(xùn)練啥呀?誰(shuí)說(shuō)一類(lèi)樣本就不能訓(xùn)練了?只要有誤差就行。還有人說(shuō),你這樣一訓(xùn)練,判別網(wǎng)絡(luò)的網(wǎng)絡(luò)參數(shù)不是也跟著變嗎?沒(méi)錯(cuò),這很關(guān)鍵,所以在訓(xùn)練這個(gè)串接的網(wǎng)絡(luò)的時(shí)候,一個(gè)很重要的操作就是不要判別網(wǎng)絡(luò)的參數(shù)發(fā)生變化,也就是不讓它參數(shù)發(fā)生更新,只是把誤差一直傳,傳到生成網(wǎng)絡(luò)那塊后更新生成網(wǎng)絡(luò)的參數(shù)。這樣就完成了生成網(wǎng)絡(luò)的訓(xùn)練了。
在完成生成網(wǎng)絡(luò)訓(xùn)練好,那么我們是不是可以根據(jù)目前新的生成網(wǎng)絡(luò)再對(duì)先前的那些噪聲Z生成新的假樣本了,沒(méi)錯(cuò),并且訓(xùn)練后的假樣本應(yīng)該是更真了才對(duì)。然后又有了新的真假樣本集(其實(shí)是新的假樣本集),這樣又可以重復(fù)上述過(guò)程了。我們把這個(gè)過(guò)程稱(chēng)作為單獨(dú)交替訓(xùn)練。我們可以實(shí)現(xiàn)定義一個(gè)迭代次數(shù),交替迭代到一定次數(shù)后停止即可。這個(gè)時(shí)候我們?cè)偃タ匆豢丛肼昛生成的假樣本會(huì)發(fā)現(xiàn),原來(lái)它已經(jīng)很真了。
看完了這個(gè)過(guò)程是不是感覺(jué)GAN的設(shè)計(jì)真的很巧妙,個(gè)人覺(jué)得最值得稱(chēng)贊的地方可能在于這種假樣本在訓(xùn)練過(guò)程中的真假變換,這也是博弈得以進(jìn)行的關(guān)鍵之處。
▌進(jìn)一步
文字的描述相信已經(jīng)讓大多數(shù)的人知道了這個(gè)過(guò)程,下面我們來(lái)看看原文中幾個(gè)重要的數(shù)學(xué)公式描述,首先我們直接上原始論文中的目標(biāo)公式吧:
上述這個(gè)公式說(shuō)白了就是一個(gè)最大最小優(yōu)化問(wèn)題,其實(shí)對(duì)應(yīng)的也就是上述的兩個(gè)優(yōu)化過(guò)程。有人說(shuō)如果不看別的,能達(dá)看到這個(gè)公式就拍案叫絕的地步,那就是機(jī)器學(xué)習(xí)的頂級(jí)專(zhuān)家,哈哈,真是前路漫漫。同時(shí)也說(shuō)明這個(gè)簡(jiǎn)單的公式意義重大。
這個(gè)公式既然是最大最小的優(yōu)化,那就不是一步完成的,其實(shí)對(duì)比我們的分析過(guò)程也是這樣的,這里現(xiàn)優(yōu)化D,然后在取優(yōu)化G,本質(zhì)上是兩個(gè)優(yōu)化問(wèn)題,把拆解就如同下面兩個(gè)公式:
優(yōu)化D:
優(yōu)化G:
可以看到,優(yōu)化D的時(shí)候,也就是判別網(wǎng)絡(luò),其實(shí)沒(méi)有生成網(wǎng)絡(luò)什么事,后面的G(z)這里就相當(dāng)于已經(jīng)得到的假樣本。優(yōu)化D的公式的第一項(xiàng),使的真樣本x輸入的時(shí)候,得到的結(jié)果越大越好,可以理解,因?yàn)樾枰鏄颖镜念A(yù)測(cè)結(jié)果越接近于1越好嘛。對(duì)于假樣本,需要優(yōu)化是的其結(jié)果越小越好,也就是D(G(z))越小越好,因?yàn)樗臉?biāo)簽為0。但是呢第一項(xiàng)是越大,第二項(xiàng)是越小,這不矛盾了,所以呢把第二項(xiàng)改成1-D(G(z)),這樣就是越大越好,兩者合起來(lái)就是越大越好。 那么同樣在優(yōu)化G的時(shí)候,這個(gè)時(shí)候沒(méi)有真樣本什么事,所以把第一項(xiàng)直接卻掉了。這個(gè)時(shí)候只有假樣本,但是我們說(shuō)這個(gè)時(shí)候是希望假樣本的標(biāo)簽是1的,所以是D(G(z))越大越好,但是呢為了統(tǒng)一成1-D(G(z))的形式,那么只能是最小化1-D(G(z)),本質(zhì)上沒(méi)有區(qū)別,只是為了形式的統(tǒng)一。之后這兩個(gè)優(yōu)化模型可以合并起來(lái)寫(xiě),就變成了最開(kāi)始的那個(gè)最大最小目標(biāo)函數(shù)了。
所以回過(guò)頭來(lái)我們來(lái)看這個(gè)最大最小目標(biāo)函數(shù),里面包含了判別模型的優(yōu)化,包含了生成模型的以假亂真的優(yōu)化,完美的闡釋了這樣一個(gè)優(yōu)美的理論。
▌再進(jìn)一步
有人說(shuō)GAN強(qiáng)大之處在于可以自動(dòng)的學(xué)習(xí)原始真實(shí)樣本集的數(shù)據(jù)分布,不管這個(gè)分布多么的復(fù)雜,只要訓(xùn)練的足夠好就可以學(xué)出來(lái)。針對(duì)這一點(diǎn),感覺(jué)有必要好好理解一下為什么別人會(huì)這么說(shuō)。
我們知道,傳統(tǒng)的機(jī)器學(xué)習(xí)方法,我們一般都會(huì)定義一個(gè)什么模型讓數(shù)據(jù)去學(xué)習(xí)。比如說(shuō)假設(shè)我們知道原始數(shù)據(jù)屬于高斯分布呀,只是不知道高斯分布的參數(shù),這個(gè)時(shí)候我們定義高斯分布,然后利用數(shù)據(jù)去學(xué)習(xí)高斯分布的參數(shù)得到我們最終的模型。再比如說(shuō)我們定義一個(gè)分類(lèi)器,比如SVM,然后強(qiáng)行讓數(shù)據(jù)進(jìn)行東變西變,進(jìn)行各種高維映射,最后可以變成一個(gè)簡(jiǎn)單的分布,SVM可以很輕易的進(jìn)行二分類(lèi)分開(kāi),其實(shí)SVM已經(jīng)放松了這種映射關(guān)系了,但是也是給了一個(gè)模型,這個(gè)模型就是核映射(什么徑向基函數(shù)等等),說(shuō)白了其實(shí)也好像是你事先知道讓數(shù)據(jù)該怎么映射一樣,只是核映射的參數(shù)可以學(xué)習(xí)罷了。
所有的這些方法都在直接或者間接的告訴數(shù)據(jù)你該怎么映射一樣,只是不同的映射方法能力不一樣。那么我們?cè)賮?lái)看看GAN,生成模型最后可以通過(guò)噪聲生成一個(gè)完整的真實(shí)數(shù)據(jù)(比如人臉),說(shuō)明生成模型已經(jīng)掌握了從隨機(jī)噪聲到人臉數(shù)據(jù)的分布規(guī)律了,有了這個(gè)規(guī)律,想生成人臉還不容易。然而這個(gè)規(guī)律我們開(kāi)始知道嗎?顯然不知道,如果讓你說(shuō)從隨機(jī)噪聲到人臉應(yīng)該服從什么分布,你不可能知道。這是一層層映射之后組合起來(lái)的非常復(fù)雜的分布映射規(guī)律。然而GAN的機(jī)制可以學(xué)習(xí)到,也就是說(shuō)GAN學(xué)習(xí)到了真實(shí)樣本集的數(shù)據(jù)分布。
再拿原論文中的一張圖來(lái)解釋:
這張圖表明的是GAN的生成網(wǎng)絡(luò)如何一步步從均勻分布學(xué)習(xí)到正太分布的。原始數(shù)據(jù)x服從正太分布,這個(gè)過(guò)程你也沒(méi)告訴生成網(wǎng)絡(luò)說(shuō)你得用正太分布來(lái)學(xué)習(xí),但是生成網(wǎng)絡(luò)學(xué)習(xí)到了。假設(shè)你改一下x的分布,不管什么分布,生成網(wǎng)絡(luò)可能也能學(xué)到。這就是GAN可以自動(dòng)學(xué)習(xí)真實(shí)數(shù)據(jù)的分布的強(qiáng)大之處。
還有人說(shuō)GAN強(qiáng)大之處在于可以自動(dòng)的定義潛在損失函數(shù)。 什么意思呢,這應(yīng)該說(shuō)的是判別網(wǎng)絡(luò)可以自動(dòng)學(xué)習(xí)到一個(gè)好的判別方法,其實(shí)就是等效的理解為可以學(xué)習(xí)到好的損失函數(shù),來(lái)比較好或者不好的判別出來(lái)結(jié)果。雖然大的loss函數(shù)還是我們?nèi)藶槎x的,基本上對(duì)于多數(shù)GAN也都這么定義就可以了,但是判別網(wǎng)絡(luò)潛在學(xué)習(xí)到的損失函數(shù)隱藏在網(wǎng)絡(luò)之中,不同的問(wèn)題這個(gè)函數(shù)就不一樣,所以說(shuō)可以自動(dòng)學(xué)習(xí)這個(gè)潛在的損失函數(shù)。
▌開(kāi)始做小實(shí)驗(yàn)
本節(jié)主要實(shí)驗(yàn)一下如何通過(guò)隨機(jī)數(shù)組生成mnist圖像。mnist手寫(xiě)體數(shù)據(jù)庫(kù)應(yīng)該都熟悉的。這里簡(jiǎn)單的使用matlab來(lái)實(shí)現(xiàn),方便看到整個(gè)實(shí)現(xiàn)過(guò)程。這里用到了一個(gè)工具箱 DeepLearnToolbox。
網(wǎng)絡(luò)結(jié)構(gòu)很簡(jiǎn)單,就定義成下面這樣子:
將上述工具箱添加到路徑,然后運(yùn)行下面代碼:
clc clear %% 構(gòu)造真實(shí)訓(xùn)練樣本 60000個(gè)樣本 1*784維(28*28展開(kāi)) load mnist_uint8; train_x = double(train_x(1:60000,:)) / 255; % 真實(shí)樣本認(rèn)為為標(biāo)簽 [1 0]; 生成樣本為[0 1]; train_y = double(ones(size(train_x,1),1)); % normalize train_x = mapminmax(train_x, 0, 1); rand('state',0) %% 構(gòu)造模擬訓(xùn)練樣本 60000個(gè)樣本 1*100維 test_x = normrnd(0,1,[60000,100]); % 0-255的整數(shù) test_x = mapminmax(test_x, 0, 1); test_y = double(zeros(size(test_x,1),1)); test_y_rel = double(ones(size(test_x,1),1)); %% nn_G_t = nnsetup([100 784]); nn_G_t.activation_function = 'sigm'; nn_G_t.output = 'sigm'; nn_D = nnsetup([784 100 1]); nn_D.weightPenaltyL2 = 1e-4; % L2 weight decay nn_D.dropoutFraction = 0.5; % Dropout fraction nn_D.learningRate = 0.01; % Sigm require a lower learning rate nn_D.activation_function = 'sigm'; nn_D.output = 'sigm'; % nn_D.weightPenaltyL2 = 1e-4; % L2 weight decay nn_G = nnsetup([100 784 100 1]); nn_G.weightPenaltyL2 = 1e-4; % L2 weight decay nn_G.dropoutFraction = 0.5; % Dropout fraction nn_G.learningRate = 0.01; % Sigm require a lower learning rate nn_G.activation_function = 'sigm'; nn_G.output = 'sigm'; % nn_G.weightPenaltyL2 = 1e-4; % L2 weight decay opts.numepochs = 1; % Number of full sweeps through data opts.batchsize = 100; % Take a mean gradient step over this many samples %% num = 1000; tic for each = 1:1500%----------計(jì)算G的輸出:假樣本------------------- for i = 1:length(nn_G_t.W) %共享網(wǎng)絡(luò)參數(shù)nn_G_t.W{i} = nn_G.W{i};endG_output = nn_G_out(nn_G_t, test_x);%-----------訓(xùn)練D------------------------------index = randperm(60000);train_data_D = [train_x(index(1:num),:);G_output(index(1:num),:)];train_y_D = [train_y(index(1:num),:);test_y(index(1:num),:)];nn_D = nntrain(nn_D, train_data_D, train_y_D, opts);%訓(xùn)練D%-----------訓(xùn)練G-------------------------------for i = 1:length(nn_D.W) %共享訓(xùn)練的D的網(wǎng)絡(luò)參數(shù)nn_G.W{length(nn_G.W)-i+1} = nn_D.W{length(nn_D.W)-i+1};end%訓(xùn)練G:此時(shí)假樣本標(biāo)簽為1,認(rèn)為是真樣本nn_G = nntrain(nn_G, test_x(index(1:num),:), test_y_rel(index(1:num),:), opts); end toc for i = 1:length(nn_G_t.W)nn_G_t.W{i} = nn_G.W{i}; end fin_output = nn_G_out(nn_G_t, test_x);函數(shù)nn_G_out為:
function output = nn_G_out(nn, x)nn.testing = 1;nn = nnff(nn, x, zeros(size(x,1), nn.size(end)));nn.testing = 0;output = nn.a{end}; end看一下這個(gè)及其簡(jiǎn)單的函數(shù),其實(shí)最值得注意的就是中間那個(gè)交替訓(xùn)練的過(guò)程,這里我分了三步列出來(lái):
重新計(jì)算假樣本(假樣本每次是需要更新的,產(chǎn)生越來(lái)越像的樣本)
訓(xùn)練D網(wǎng)絡(luò),一個(gè)二分類(lèi)的神經(jīng)網(wǎng)絡(luò);
訓(xùn)練G網(wǎng)絡(luò),一個(gè)串聯(lián)起來(lái)的長(zhǎng)網(wǎng)絡(luò),也是一個(gè)二分類(lèi)的神經(jīng)網(wǎng)絡(luò)(不過(guò)只有假樣本來(lái)訓(xùn)練),同時(shí)D部分參數(shù)在下一次的時(shí)候不能變了。
就這樣調(diào)一調(diào)參數(shù),最終輸出在fin_output里面,多運(yùn)行幾次顯示不同運(yùn)行次數(shù)下的結(jié)果:
可以看到的是結(jié)果還是有點(diǎn)像模像樣的。
▌實(shí)驗(yàn)總結(jié)
運(yùn)行上述簡(jiǎn)單的網(wǎng)絡(luò)我發(fā)現(xiàn)幾個(gè)問(wèn)題:
網(wǎng)絡(luò)存在著不收斂問(wèn)題;網(wǎng)絡(luò)不穩(wěn)定;網(wǎng)絡(luò)難訓(xùn)練;讀過(guò)原論文其實(shí)作者也提到過(guò)這些問(wèn)題,包括GAN剛出來(lái)的時(shí)候,很多人也在致力于解決這些問(wèn)題,當(dāng)你實(shí)驗(yàn)自己碰到的時(shí)候,還是很有意思的。那么這些問(wèn)題怎么體現(xiàn)的呢,舉個(gè)例子,可能某一次你會(huì)發(fā)現(xiàn)訓(xùn)練的誤差很小,在下一代訓(xùn)練時(shí),馬上又出現(xiàn)極限性的上升的很厲害,過(guò)幾代又發(fā)現(xiàn)訓(xùn)練誤差很小,震蕩太嚴(yán)重。
其次網(wǎng)絡(luò)需要調(diào)才能出像樣的結(jié)果。交替迭代次數(shù)的不同結(jié)果也不一樣。比如每一代訓(xùn)練中,D網(wǎng)絡(luò)訓(xùn)練2回,G網(wǎng)絡(luò)訓(xùn)練一回,結(jié)果就不一樣。
這是簡(jiǎn)單的無(wú)條件GAN,所以每一代訓(xùn)練完后,只能出現(xiàn)一個(gè)結(jié)果,那就是0-9中的某一個(gè)數(shù)。要想在一代訓(xùn)練中出現(xiàn)好幾種結(jié)果,就需要使用到條件GAN了。
▌再提升一下
圖像生成
這里嘗試給「圖像生成」一個(gè)大致定義:圖像生成的目的是,學(xué)習(xí)一個(gè)生成模型,能夠?qū)?lái)自于輸入分布的一幅圖像或變量轉(zhuǎn)變成為一幅輸出圖像。這里,我們不僅要求「輸入」?jié)M足一個(gè)輸入分布,同樣,我們還要求「輸出」?jié)M足一個(gè)預(yù)期的期望分布。通過(guò)定義不同的輸入分布和期望分布,就對(duì)應(yīng)著不同的圖像生成問(wèn)題。
一開(kāi)始,最標(biāo)準(zhǔn)的 GAN 的假設(shè)是,輸入要服從隨機(jī)噪聲分布,期望分布是所有的真實(shí)圖像。這個(gè)問(wèn)題一開(kāi)始定義得太大,所以雖然 GAN 在2014年就出現(xiàn)了,在2014年到2016年這段時(shí)間其實(shí)發(fā)展得并不快。
后來(lái)大家就去思考,輸入的分布也可以不是隨機(jī)分布,于是大家開(kāi)始根據(jù)各種實(shí)際問(wèn)題的需要來(lái)定義自己需要的輸入分布和期望分布。比如,輸入分布可以是來(lái)自于所有斑馬的一幅圖像,輸出分布是所有正常馬的圖像,這樣系統(tǒng)要學(xué)習(xí)的其實(shí)是這兩種圖像之間的映射(mapping)。
同樣,如果我們輸入的是一個(gè)低分辨率圖像,輸出的是一個(gè)高分辨率圖像,那么希望系統(tǒng)學(xué)習(xí)到的是低分辨率和高分辨率之間的映射。去區(qū)塊(deblocking),輸入的是 JPEG 壓縮圖像,輸出的是真實(shí)高清圖像,我們也是希望學(xué)到兩者之間的映射。人臉領(lǐng)域也是一樣,比如我們做的超分辨和性別轉(zhuǎn)換。輸入是男性圖像,輸出是女性圖像,學(xué)習(xí)二者之間的映射。
另外一個(gè)有意思的是圖像文本描述的自動(dòng)生成(Image captioning),輸入的是圖像,輸出的是句子。大家以前就認(rèn)為,這是一個(gè)一對(duì)一的映射,其實(shí)不是。它實(shí)際上是一對(duì)多的映射。不同的人來(lái)描述一幅圖,就會(huì)產(chǎn)生不同的語(yǔ)句。所以如果用 GAN 來(lái)做這件事情,應(yīng)該是很有趣的,今年有幾篇投 ICCV 的文章做的就是這方面的工作。
關(guān)于GAN的三個(gè)問(wèn)題
第一,度量復(fù)雜分布之間差異性。我們希望輸出分布達(dá)到期望分布,那么我們需要找兩個(gè)分布之間差異的度量方式,這是我認(rèn)為在 GAN 里面需要研究的第一個(gè)關(guān)鍵性問(wèn)題。
第二,如何設(shè)計(jì)生成器。如果我們想要學(xué)習(xí)映射,就需要一個(gè)生成器,那么就要對(duì)它的訓(xùn)練、可學(xué)習(xí)性進(jìn)行設(shè)計(jì)。這是 GAN 里面另一個(gè)可以研究的角度。
第三,連接輸入和輸出。如下圖右邊性別轉(zhuǎn)換的例子,輸入是一張男性圖像,輸出是一張女性圖像。顯然我們需要的并不是從輸入到任意一幅女性人臉圖像的映射,二是要求輸出的女性圖像要跟輸入的男性圖像盡可能像,這個(gè)轉(zhuǎn)換才是有意義的。所以,這就是 GAN 里面另外一個(gè)重要的研究方向,就是如何將輸入和輸出連接起來(lái)。
下面針對(duì)這三個(gè)問(wèn)題,進(jìn)行詳細(xì)的講解。
如何度量?jī)蓚€(gè)分布之間的差異性
GAN 使用了一個(gè)分類(lèi)器來(lái)度量輸出分布和期望分布的差異性。實(shí)際上,Torralba和Efros在 2011 年的時(shí)候也考慮過(guò)用一個(gè)分類(lèi)器來(lái)分析兩個(gè)分布之間差異,這也是當(dāng)時(shí)做 domain adaption 的學(xué)者喜歡引用的一篇論文。他們?cè)O(shè)計(jì)了一個(gè)實(shí)驗(yàn),給你三張圖像,讓你猜是來(lái)自 12 個(gè)數(shù)據(jù)集(包括 ImageNet、COCO和PASCAL VOC等)中的那一個(gè)。如果是隨機(jī)猜的話,顯然猜中的概率是 1/12。但是人猜中的準(zhǔn)確度往往能達(dá)到 30% 左右,說(shuō)明不同數(shù)據(jù)集刻畫(huà)的分布是不一致的。這里人其實(shí)可視為一個(gè)分類(lèi)器,通過(guò)判斷樣本來(lái)自于那個(gè)數(shù)據(jù)集來(lái)分析兩個(gè)分布之間差異。
雖然 NIPS 2014 年的這篇 GAN 論文沒(méi)有引用 Torralba 的工作,其實(shí)它也是采用了一個(gè)判別器來(lái)度量?jī)蓚€(gè)分布的差異化程度。基本的過(guò)程是,固定生成器,得到一個(gè)最好的判別器,再固定判別器,學(xué)到一個(gè)最好的生成器。但是其中有一個(gè)最令人擔(dān)心的問(wèn)題,那就是如果我們學(xué)到的是一個(gè)很復(fù)雜的分布,就會(huì)出現(xiàn)模式崩潰(Mode Collapse)的問(wèn)題,即無(wú)法學(xué)習(xí)復(fù)雜分布的全局,只能學(xué)習(xí)其中的一部分。
對(duì)此,最早的解決方案,是調(diào)整生成器(G)和判別器(D)的優(yōu)化次序,但這也不是一個(gè)終極方案。從去年開(kāi)始大家開(kāi)始關(guān)注要去找到一個(gè)終極解決方案。
那之前,大家怎么去解決這個(gè)問(wèn)題呢?使用的是原來(lái)機(jī)器學(xué)習(xí)里常用的方法:最大化均值差異(Maximum Mean Discrepancy,MMD)。
如果兩個(gè)分布相同的話,那么兩個(gè)分布的數(shù)學(xué)期望顯然也應(yīng)該相同;然而,如果兩個(gè)分布的數(shù)學(xué)期望相同,并不能保證兩個(gè)分布相同。因而,我們需要更好地建立「分布相同」和「期望相同」之間的連接關(guān)系。幸運(yùn)的是,我們可以對(duì)來(lái)自于兩個(gè)分布的變量施加同樣的非線性變換。如果對(duì)于所有的非線性變換下兩個(gè)分布的數(shù)學(xué)期望均相同(即:兩個(gè)分布的期望的最大差別為0),在統(tǒng)計(jì)學(xué)意義上就可以保證兩個(gè)分布是相同的。不幸的是,這種方法需要我們遍歷所有的非線性變換,從實(shí)踐的角度似乎任由一定難度。一開(kāi)始,在機(jī)器學(xué)習(xí)領(lǐng)域,大家傾向于用線性 kernel或Gaussian RBF kernel 來(lái)進(jìn)行非線性變換,后來(lái)開(kāi)始采用 multi-kernel。從去年開(kāi)始,大家開(kāi)始用 CNN 來(lái)近似所有的非線性變換,在 MMD 框架下進(jìn)行圖像生成。首先,固定生成器并最大化 MMD,然后固定判別器里 MMD 的 f,然后通過(guò)最小化 MMD 來(lái)更新生成器。
最常用的一種方法,就是拿 MMD 來(lái)代替判別器,去學(xué)習(xí)一個(gè) CNN,這是 ICML 2015 的一篇文章中嘗試的方法,我們自己也在這個(gè)基礎(chǔ)上做了一些工作。
但實(shí)際上,如果直接拿 MMD 去替換生成器,雖然有一定效果,但不是特別成功。所以,從 NIPS 2016 開(kāi)始,就出現(xiàn)了一個(gè) Improved GAN,這個(gè)工作雖然沒(méi)有引用 MMD 的論文,但實(shí)際上在更新判別器的同時(shí)也最小化了 MMD。等到了 Wasserstein GAN 的時(shí)候,它就明確解釋了與 MMD 之間的聯(lián)系,雖然論文里寫(xiě)的是一個(gè)「減」的關(guān)系,但我們看它的代碼,它也是要加上一個(gè)范數(shù)的,因?yàn)橹皇亲寖蓚€(gè)分布的期望最大化或最小化都不能保證分布的差異化程度最小。
然后,最近 ICLR 2017 的一篇論文也明確指出要用 MMD 來(lái)作為 GAN 網(wǎng)絡(luò)的停止條件和學(xué)習(xí)效果的評(píng)價(jià)手段。
如何設(shè)計(jì)一個(gè)生成器
這個(gè)部分相對(duì)來(lái)說(shuō)就比較容易一些。早期的時(shí)候,GAN 的一個(gè)最大的進(jìn)步就是 DCGAN,用于圖像生成時(shí),比較合適的選擇就是用全卷積網(wǎng)絡(luò)加上 batch normalizaiton。
對(duì)于復(fù)雜的圖像生成,可以使用分階段的方式。比如,第一步可以生成小圖,然后由小圖生成大圖。沿著這個(gè)方向,香港中文大學(xué)王曉剛老師和康奈爾大學(xué) John Hopcroft 都做了一些工作。
對(duì)于圖像增強(qiáng)(image enhancement)相關(guān)的一些任務(wù),包括超分辨率和人臉屬性轉(zhuǎn)換(Face Attribute Transfer),目前在有監(jiān)督時(shí)表現(xiàn)最好網(wǎng)絡(luò)是 ResNet ,所以我們?cè)谶@些任務(wù)中實(shí)用GAN時(shí)一般也會(huì)采用 ResNet 結(jié)構(gòu)。
同樣,對(duì)于圖像轉(zhuǎn)換(image translation),基本上用的是 U-Net 結(jié)構(gòu)。我們?cè)谧龌谝龑?dǎo)圖像的人臉填充(guided face completion)時(shí)也采用了 U-Net 結(jié)構(gòu)。
對(duì)于圖像文本描述的自動(dòng)生成,顯然應(yīng)該采用 CNN+RNN 這樣的網(wǎng)絡(luò)結(jié)構(gòu)。總而言之,一個(gè)比較好的建議就是根據(jù)任務(wù)的特點(diǎn)和前任的經(jīng)驗(yàn)來(lái)設(shè)計(jì)生成器網(wǎng)絡(luò)。
如何連接輸入和輸出
如何通過(guò)連接輸入和輸出的方式來(lái)改善 GAN 的可學(xué)習(xí)性,這個(gè)問(wèn)題是從 NIPS 2016 開(kāi)始得到了較多的關(guān)注,同時(shí)這也是我自己非常感興趣的一個(gè)方向。比較早的一個(gè)工作就是 InfoGAN,其特點(diǎn)就是輸入包括兩個(gè)部分:C(隱變量)和 Z(噪聲)。InfoGAN 生成圖像之后,不僅要求生成圖像和真實(shí)圖像難以區(qū)分,還要求能夠從生成圖像中預(yù)測(cè)出 C,這樣就為輸入和輸出建立起了一個(gè)聯(lián)系。
另外針對(duì)一些任務(wù),比如超分辨,可以用 Perceptual loss 的方式來(lái)建立輸入和輸出的聯(lián)系。
我們?cè)谧鋈四槍傩赞D(zhuǎn)換時(shí)發(fā)現(xiàn),現(xiàn)有的 Perceptual loss 往往是定義在一個(gè)現(xiàn)有的網(wǎng)絡(luò)基礎(chǔ)上的,我們就想能不能把 Perceptual loss 網(wǎng)絡(luò)和判別器結(jié)合起來(lái),所以就提出了一個(gè) Adaptive perceptual loss。結(jié)果表明Adaptive perceptual loss能夠具有更好的自適應(yīng)性,能夠更好地建立輸入和輸出的聯(lián)系和顯著改善生成圖片的視覺(jué)效果。
當(dāng)輸入和輸出都是已知時(shí)(比如圖像超分辨和圖像轉(zhuǎn)換),要用什么方式來(lái)連接輸入和輸出呢?以前是用 Perceptual loss 來(lái)連,現(xiàn)在更好的方式是用 Conditional GAN。假設(shè)有一個(gè) Positive Pair(輸入和groundtruth圖像)和 Negative Pair(輸入和生成圖像),那么判別器就不是在兩幅圖像之間做判別,而是在兩個(gè)「Pair」之間做判別。這樣的話,輸入就很自然地引入到了判別器中。
在此基礎(chǔ)上,我們還考慮了當(dāng)有一些額外的 Guidance 時(shí),如何來(lái)更好地建立輸入和輸出的聯(lián)系。
上面提到,在有監(jiān)督的情況下 Conditional GAN 是一個(gè)比較好的選擇。但如果在 unpair 的情況下做圖像轉(zhuǎn)換,要如何建立輸入和輸出的聯(lián)系?譚平老師他們組和Efros組今年就做了這方面的工作,其實(shí)去年投 CVPR2017 的一篇論文也做了類(lèi)似的工作。我們知道,由于是unpair的,原則上訓(xùn)練階段輸入和輸出不能直接建立聯(lián)系。這時(shí)他們采用的是一種 Cycle-Consistent 的方式。從 X 可以預(yù)測(cè)和生成 Y,再?gòu)?Y 重新生成 X',那么由 Y 生成的 X' 就能跟輸入的 X 建立聯(lián)系。這樣的話,我們實(shí)際上相當(dāng)于隱式地建立了從 X 到 Y 的聯(lián)系。
總結(jié)
現(xiàn)在的GAN已經(jīng)到了五花八門(mén)的時(shí)候了,各種GAN應(yīng)用也很多,理解底層原理再慢慢往上層擴(kuò)展。GAN還是一個(gè)很厲害的東西,它使得現(xiàn)有問(wèn)題從有監(jiān)督學(xué)習(xí)慢慢過(guò)渡到無(wú)監(jiān)督學(xué)習(xí),而無(wú)監(jiān)督學(xué)習(xí)才是自然界中普遍存在的,因?yàn)楹芏鄷r(shí)候沒(méi)有辦法拿到監(jiān)督信息的。要不Yann Lecun贊嘆GAN是機(jī)器學(xué)習(xí)近十年來(lái)最有意思的想法。
原文地址:
https://blog.csdn.net/on2way/article/details/72773771
原文發(fā)布時(shí)間為:2018-07-10
本文來(lái)自云棲社區(qū)合作伙伴“機(jī)器學(xué)習(xí)算法與Python學(xué)習(xí)”,了解相關(guān)信息可以關(guān)注“機(jī)器學(xué)習(xí)算法與Python學(xué)習(xí)”
總結(jié)
以上是生活随笔為你收集整理的一文读懂生成对抗网络(GANs)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 从0开始Vue2集成Bootstrap4
- 下一篇: IDEA中的HTTP Client Ed