【生成模型】解读显式生成模型之完全可见置信网络FVBN
上一期為大家說(shuō)明了什么是極大似然法,以及如何使用極大似然法搭建生成模型,本期將為大家介紹第一個(gè)顯式生成模型完全可見(jiàn)置信網(wǎng)絡(luò)FVBN。
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 作者&編輯 | 小米粥
1 完全可見(jiàn)置信網(wǎng)絡(luò)
在完全可見(jiàn)置信網(wǎng)絡(luò)中,不存在不可觀察的潛在變量,觀察變量的概率被鏈?zhǔn)椒▌t從維度上進(jìn)行分解,對(duì)于 n 維觀察變量x ,其概率表達(dá)式為:
自回歸網(wǎng)絡(luò)是最簡(jiǎn)單的完全可見(jiàn)置信網(wǎng)絡(luò),其中每一個(gè)維度的觀察變量都構(gòu)成概率模型的一個(gè)節(jié)點(diǎn),而這些所有的節(jié)點(diǎn){x1,x2,...,xn}共同構(gòu)成一個(gè)完全有向圖,即圖中任意兩個(gè)節(jié)點(diǎn)都存在連接關(guān)系,如圖所示。
在自回歸網(wǎng)絡(luò)中,因?yàn)橐呀?jīng)有了隨機(jī)變量的鏈?zhǔn)椒纸怅P(guān)系,那么核心問(wèn)題便成為如何表達(dá)條件概率p(xi|xi-1,xx-2,...,x1) 。最簡(jiǎn)單的模型是線性自回歸網(wǎng)絡(luò),即每個(gè)條件概率均被定義為線性模型,對(duì)實(shí)數(shù)值數(shù)據(jù)使用線性回歸模型(例如定義 p(xi|xi-1,xx-2,...,x1)= w1x1+w2x2+...+wi-1xi-1?,對(duì)二值數(shù)據(jù)使用邏輯回歸,而對(duì)離散數(shù)據(jù)使用softmax回歸,其計(jì)算過(guò)程如下圖。
但線性模型容量有限,擬合函數(shù)的能力不足。在神經(jīng)自回歸網(wǎng)絡(luò)中,使用神經(jīng)網(wǎng)絡(luò)代替線性模型,它可以任意增加容量,理論上可以擬合任意聯(lián)合分布。神經(jīng)自回歸網(wǎng)絡(luò)還使用了特征重用的技巧,神經(jīng)網(wǎng)絡(luò)從觀察變量 xi 學(xué)習(xí)到的隱藏抽象特征 hi 不僅在計(jì)算p(xi+1|xi,xi-1,...,x1)時(shí)使用,也會(huì)在計(jì)算p(xi+2|xi+1,xi,...,x1)時(shí)進(jìn)行重用,其計(jì)算圖如下所示,并且該模型不需要將每個(gè)條件概率的計(jì)算都分別使用不同神經(jīng)網(wǎng)絡(luò)表示,可以將所有神經(jīng)網(wǎng)絡(luò)整合為一個(gè),因此只要設(shè)計(jì)成抽象特征hi只依賴(lài)于x1,x2,...,xi即可。而目前的神經(jīng)自回歸密度估計(jì)器是神經(jīng)自回歸網(wǎng)絡(luò)中最具有代表性的方案,它是在神經(jīng)自回歸網(wǎng)絡(luò)中引入了參數(shù)共享的方案,即從觀察變量xi到任意隱藏抽象特征 hi+1,hi+2,... 的權(quán)值參數(shù)是共享的,使用了特征重用、參數(shù)共享等深度學(xué)習(xí)技巧的神經(jīng)自回歸密度估計(jì)器具有非常優(yōu)秀的性能。
PixelRNN和PixelCNN也屬于完全可見(jiàn)置信網(wǎng)絡(luò),從名字可以看出,這兩個(gè)模型一般用于圖像的生成。它們將圖像x的概率p(x)按照像素分解為 n 個(gè)條件概率的乘積,其中n為圖像的像素點(diǎn)個(gè)數(shù),即在每一個(gè)像素點(diǎn)上定義了一個(gè)條件概率用以表達(dá)像素之間的依賴(lài)關(guān)系,該條件概率分別使用RNN或者CNN進(jìn)行學(xué)習(xí)。為了將輸出離散化,通常將RNN或CNN的最后一層設(shè)置為softmax層,用以表示其輸出不同像素值的概率。在PixelRNN中,一般定義從左上角開(kāi)始沿著右方和下方依次生成每一個(gè)像素點(diǎn),如下圖所示。這樣,對(duì)數(shù)似然的表達(dá)式便可以得到,訓(xùn)練模型時(shí)只需要將其極大化即可。
PixelRNN在其感受野內(nèi)可能具有無(wú)邊界的依賴(lài)范圍,因?yàn)榇笪恢玫南袼刂狄蕾?lài)之前所有已知像素點(diǎn)的像素值,這將需要大量的計(jì)算代價(jià),PixelCNN使用標(biāo)準(zhǔn)卷積層來(lái)捕獲有界的感受野,其訓(xùn)練速度要快于PixelRNN。在PixelCNN中,每個(gè)位置的像素值僅與其周?chē)阎袼攸c(diǎn)的值有關(guān),如下圖所示。灰色部分為已知像素,而白色部分為未知像素,計(jì)算黑色位置的像素值時(shí),需要把方框區(qū)域內(nèi)的所有灰色像素值傳遞給CNN,由CNN最后的softmax輸出層來(lái)表達(dá)表在黑色位置取不同像素值的概率,這里可以使用由0和1構(gòu)成的掩模矩陣將方框區(qū)域內(nèi)的白色位置像素抹掉。PixelRNN和PixelCNN此后仍有非常多改進(jìn)模型,但由于它是逐個(gè)像素點(diǎn)地生成圖片,具有串行性,故在實(shí)際應(yīng)用中效率難以保證,這也是FVBN模型的通病。
2 pixelCNN 代碼
接下來(lái)我們將提供一份完整的pixelCNN的代碼講解,其中訓(xùn)練集為mnist數(shù)據(jù)集。
首先讀取相關(guān)python庫(kù),設(shè)置訓(xùn)練參數(shù):
# 讀取相關(guān)庫(kù)?
import time?
import torch?
import torch.nn.functional as F?
from torch import nn, optim, cudafrom torch.utils?
import datafrom torchvision import datasets, transforms, utils?
# 設(shè)置訓(xùn)練參數(shù)?
train_batch_size = 256?
generation_batch_size = 48?
epoch_number = 25feature_dim = 64?
# 是否使用GPU?
if torch.cuda.is_available(): ? ?
????device = torch.device('cuda:0')?
else: ? ?
????device = torch.device('cpu')
然后定義二維掩膜卷積,所謂掩膜即使卷積中心的右方和下方的權(quán)值為0,如下圖所示為3x3掩膜卷積核(A型):
定義二維掩膜卷積核,其中有A與B兩種類(lèi)型,區(qū)別之處在于中心位置是否被卷積計(jì)算:
class MaskedConv2d(nn.Conv2d):
?? ?def __init__(self, mask_type, *args, **kwargs):
? ? ? ?super(MaskedConv2d, self).__init__(*args, **kwargs)
? ? ? ?assert mask_type in {'A', 'B'}
? ? ? ?self.register_buffer('mask', self.weight.data.clone())
? ? ? ?bs, o_feature_dim, kH, kW = self.weight.size()
? ? ? ?self.mask.fill_(1)
? ? ? ?self.mask[:, :, kH // 2, kW // 2 + (mask_type == 'B'):] = 0 ? ? ? ?
????????self.mask[:, :, kH // 2 + 1:] = 0
????def forward(self, x): ? ? ? ?
????????self.weight.data *= self.mask ? ? ? ?
????????return super(MaskedConv2d, self).forward(x)
我們的pixelCNN網(wǎng)絡(luò)為多層掩膜卷積的堆疊,即:
network = nn.Sequential(
????MaskedConv2d('A',1,feature_dim,7,1,3, bias=False),nn.BatchNorm2d(feature_dim),nn.ReLU(True),? ? ?
????MaskedConv2d('B', feature_dim, feature_dim, 7, 1, 3, bias=False), nn.BatchNorm2d(feature_dim), nn.ReLU(True),? ??
????MaskedConv2d('B', feature_dim, feature_dim, 7, 1, 3, bias=False), nn.BatchNorm2d(feature_dim), nn.ReLU(True), ? ?
????MaskedConv2d('B', feature_dim, feature_dim, 7, 1, 3, bias=False), nn.BatchNorm2d(feature_dim), nn.ReLU(True), ? ?
????MaskedConv2d('B', feature_dim, feature_dim, 7, 1, 3, bias=False), nn.BatchNorm2d(feature_dim), nn.ReLU(True), ? ?
????MaskedConv2d('B', feature_dim, feature_dim, 7, 1, 3, bias=False), nn.BatchNorm2d(feature_dim), nn.ReLU(True), ? ?
????MaskedConv2d('B', feature_dim, feature_dim, 7, 1, 3, bias=False), nn.BatchNorm2d(feature_dim), nn.ReLU(True), ? ?
????MaskedConv2d('B', feature_dim, feature_dim, 7, 1, 3, bias=False), nn.BatchNorm2d(feature_dim), nn.ReLU(True), ? ?nn.Conv2d(feature_dim, 256, 1))?
network.to(device)
接著設(shè)置dataloader和優(yōu)化器:
train_data = data.DataLoader(datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor()), ? ? ? ? ? ? ? ? ? ? batch_size=train_batch_size, shuffle=True, num_workers=1, pin_memory=True)?
test_data = data.DataLoader(datasets.MNIST('data', train=False, download=True, transform=transforms.ToTensor()), ? ? ? ? ? ? ? ? ? ? batch_size=train_batch_size, shuffle=False, num_workers=1, pin_memory=True)?
?optimizer = optim.Adam(network.parameters())
開(kāi)始訓(xùn)練網(wǎng)絡(luò),并在每一輪epoch后進(jìn)行測(cè)試和生成樣本
if __name__ == "__main__":
?? ?for epoch in range(epoch_number):
?? ? ? ?# 訓(xùn)練
?? ? ? ?cuda.synchronize()
?? ? ? ?network.train(True)
?? ? ? ?for input_image, _ in train_data:
?? ? ? ? ? ?time_tr = time.time()
?? ? ? ? ? ?input_image = input_image.to(device)
?? ? ? ? ? ?output_image = network(input_image)
?? ? ? ? ? ?target = (input_image.data[:, 0] * 255).long().to(device)
?? ? ? ? ? ?loss = F.cross_entropy(output_image, target)
?? ? ? ? ? ?optimizer.zero_grad()
?? ? ? ? ? ?loss.backward()
?? ? ? ? ? ?optimizer.step()
?? ? ? ? ? ?print("train: {} epoch, loss: {}, cost time: {}".format(epoch, loss.item(), time.time() - time_tr)) ? ? ? ?cuda.synchronize()
?? ? ? ?# 測(cè)試
?? ? ? ?with torch.no_grad():
?? ? ? ? ? ?cuda.synchronize()
?? ? ? ? ? ?time_te = time.time()
?? ? ? ? ? ?network.train(False)
?? ? ? ? ? ?for input_image, _ in test_data:? ? ? ? ? ? ? ? ????????????? ? ? ? ? ? ? ? ? input_image = input_image.to(device)
?? ? ? ? ? ? ? ?target = (input_image.data[:, 0] * 255).long().to(device)
?? ? ? ? ? ? ? ?loss = F.cross_entropy(network(input_image), target)
?? ? ? ? ? ?cuda.synchronize()
?? ? ? ? ? ?time_te = time.time() - time_te
?? ? ? ? ? ?print("test: {} epoch, loss: {}, cost time: {}".format(epoch, loss.item(), time_te)) ? ? ? ?
# 生成樣本
?? ? ? ?with torch.no_grad():
?? ? ? ? ? ?image = torch.Tensor(generation_batch_size, 1, 28, 28).to(device)
?? ? ? ? ? ?image.fill_(0)
?? ? ? ? ? ?network.train(False)
?? ? ? ? ? ?for i in range(28):
?? ? ? ? ? ? ? ?for j in range(28):
?? ? ? ? ? ? ? ? ? ?out = network(image)
?? ? ? ? ? ? ? ? ? ?probs = F.softmax(out[:, :, i, j]).data
?? ? ? ? ? ? ? ? ? ?image[:, :, i, j] = torch.multinomial(probs, 1).float() / 255.
?? ? ? ? ? ?utils.save_image(image, 'generation-image_{:02d}.png'.format(epoch), nrow=12, padding=0)
[1]?Oord A V D , Kalchbrenner N , Kavukcuoglu K . Pixel Recurrent Neural Networks[J]. 2016.
[2] 伊恩·古德費(fèi)洛, 約書(shū)亞·本吉奧, 亞倫·庫(kù)維爾. 深度學(xué)習(xí)
總結(jié)
本期帶大家學(xué)習(xí)了第一種顯式生成模型完全可見(jiàn)置信網(wǎng)絡(luò),并對(duì)其中的自回歸網(wǎng)絡(luò)和pixelRNN,pixelCNN做了講解,并講解了一份完整的pixelCNN代碼。下一期我們將對(duì)第二個(gè)顯式模型流模型進(jìn)行講解。
個(gè)人知乎,歡迎關(guān)注
GAN群
有三AI建立了一個(gè)GAN群,便于有志者相互交流。感興趣的同學(xué)也可以微信搜索xiaozhouguo94,備注“加入有三-GAN群”。
更多GAN的學(xué)習(xí)
知識(shí)星球是有三AI的付費(fèi)內(nèi)容社區(qū),里面包超過(guò)100種經(jīng)典GAN模型的解讀,了解詳細(xì)請(qǐng)閱讀以下文章:
【雜談】有三AI知識(shí)星球指導(dǎo)手冊(cè)出爐!和公眾號(hào)相比又有哪些內(nèi)容?
有三AI秋季劃GAN學(xué)習(xí)小組,可長(zhǎng)期跟隨有三學(xué)習(xí)GAN相關(guān)的內(nèi)容,并獲得及時(shí)指導(dǎo),了解詳細(xì)請(qǐng)閱讀以下文章:
【雜談】如何讓2020年秋招CV項(xiàng)目能力更加硬核,可深入學(xué)習(xí)有三秋季劃4大領(lǐng)域32個(gè)方向
轉(zhuǎn)載文章請(qǐng)后臺(tái)聯(lián)系
侵權(quán)必究
往期精選
【GAN優(yōu)化】GAN優(yōu)化專(zhuān)欄上線,首談生成模型與GAN基礎(chǔ)
【GAN的優(yōu)化】從KL和JS散度到fGAN
【GAN優(yōu)化】詳解對(duì)偶與WGAN
【GAN優(yōu)化】詳解SNGAN(頻譜歸一化GAN)
【GAN優(yōu)化】一覽IPM框架下的各種GAN
【GAN優(yōu)化】GAN優(yōu)化專(zhuān)欄欄主小米粥自述,腳踏實(shí)地,莫問(wèn)前程
【GAN優(yōu)化】GAN訓(xùn)練的幾個(gè)問(wèn)題
【GAN優(yōu)化】GAN訓(xùn)練的小技巧
【GAN優(yōu)化】從動(dòng)力學(xué)視角看GAN是一種什么感覺(jué)?
【GAN優(yōu)化】小批量判別器如何解決模式崩潰問(wèn)題
【GAN優(yōu)化】長(zhǎng)文綜述解讀如何定量評(píng)價(jià)生成對(duì)抗網(wǎng)絡(luò)(GAN)
【技術(shù)綜述】有三說(shuō)GANs(上)
【模型解讀】歷數(shù)GAN的5大基本結(jié)構(gòu)
【百戰(zhàn)GAN】如何使用GAN拯救你的低分辨率老照片
【百戰(zhàn)GAN】二次元宅們,給自己做一個(gè)專(zhuān)屬動(dòng)漫頭像可好!
【百戰(zhàn)GAN】羨慕別人的美妝?那就用GAN復(fù)制粘貼過(guò)來(lái)
【百戰(zhàn)GAN】GAN也可以拿來(lái)做圖像分割,看起來(lái)效果還不錯(cuò)?
【百戰(zhàn)GAN】新手如何開(kāi)始你的第一個(gè)生成對(duì)抗網(wǎng)絡(luò)(GAN)任務(wù)
【百戰(zhàn)GAN】自動(dòng)增強(qiáng)圖像對(duì)比度和顏色美感,GAN如何做?
【直播回放】80分鐘剖析GAN如何從各個(gè)方向提升圖像的質(zhì)量
【直播回放】60分鐘剖析GAN如何用于人臉的各種算法
總結(jié)
以上是生活随笔為你收集整理的【生成模型】解读显式生成模型之完全可见置信网络FVBN的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 【生成模型】极大似然估计,你必须掌握的概
- 下一篇: 【直播回放】60分钟了解各类图像和视频生