从完美KL距离推导VAE
文章目錄
- VAE的純邏輯推導
- 一、初始設定
- VAE流程
- 核心推導
- 從理想出發
- 從理想假設 (Assumption1Assumption_1Assumption1?)
- 化簡LossLossLoss
- 采樣優化(Assumption2Assumption_2Assumption2?)
- 二、樣例解釋
- 三、總結
VAE的純邏輯推導
一、初始設定
-
為了更具體的邏輯討論,我們假定表示輸入圖片的隨機變量XXX(以下簡稱輸入圖片),表示編碼的隨機變量hhh(以下簡稱編碼),表示得到的重構圖片的隨機變量X^\hat XX^(以下簡稱重構圖片)
-
對于編碼器和解碼器,我們用三個變量來表示,{?\phi?, hhh, XXX},其中?\phi?表示“器”的所有參數,hhh表示編碼,XXX表示圖片(對于編碼器是輸入圖片,對于解碼器是重構圖片)
-
我們所有的假設都是趨于【完美】,所有的假設的根本出發點都是為了設計出能求的算法
-
P?(X,h)P_{\phi}(X,h)P??(X,h)表示在給定?\phi?的參數下,XXX和hhh的聯合分布
-
KL(P(x)∣∣P(y))KL(P(x)||P(y))KL(P(x)∣∣P(y))表示兩個分布xxx, yyy的KL距離
VAE流程
XXX →\rightarrow→ Encoder(參數:?\phi?) →\rightarrow→ hhh →\rightarrow→ Decoder(參數:θ\thetaθ) →\rightarrow→ X^\hat XX^
核心推導
從理想出發
【完美】情況下,我們希望得到一個完全相同的編碼器和解碼器,即
Loss(?,θ,X,h,X^)=KL(P?(X,h)∣∣Pθ(X^,h))=0Loss(\phi, \theta,X,h,\hat{X}) = KL(P_{\phi}(X,h)||P_{\theta}(\hat{X},h))=0 Loss(?,θ,X,h,X^)=KL(P??(X,h)∣∣Pθ?(X^,h))=0
但是可能得不到,那就盡量減少KL(?∣∣?)KL(\cdot||\cdot)KL(?∣∣?)吧(KL距離非負):
min?Loss(?,θ,X,h,X^)=KL(P?(X,h)∣∣Pθ(X^,h))\min \quad Loss(\phi, \theta,X,h,\hat{X}) =KL(P_{\phi}(X,h)||P_{\theta}(\hat{X},h)) minLoss(?,θ,X,h,X^)=KL(P??(X,h)∣∣Pθ?(X^,h))
數學知識告訴我們
Loss(?,θ,X,h,X^)=KL(P?(X,h)∣∣Pθ(X^,h))=∑X,h,X^P?(h,X)log?P?(h,X)Pθ(h,X^)=∑X,h,X^P?(X)P?(h∣X)log?P?(X)P?(h∣X)Pθ(X^)Pθ(h∣X^)\begin{aligned} Loss(\phi, \theta,X,h,\hat{X}) = &KL(P_{\phi}(X,h)||P_{\theta}(\hat{X},h))\\ =& \sum_{X,h,\hat{X}} P_{\phi}(h,X)\log\frac{P_{\phi}(h,X)}{P_{\theta}(h,\hat{X})} \\ =&\sum_{X,h,\hat{X}} P_\phi(X)P_{\phi}(h|X)\log\frac{P_\phi(X)P_{\phi}(h|X)}{P_{\theta}(\hat{X})P_{\theta}(h|\hat{X})} \end{aligned} Loss(?,θ,X,h,X^)===?KL(P??(X,h)∣∣Pθ?(X^,h))X,h,X^∑?P??(h,X)logPθ?(h,X^)P??(h,X)?X,h,X^∑?P??(X)P??(h∣X)logPθ?(X^)Pθ?(h∣X^)P??(X)P??(h∣X)??
到此為止,能推導的部分已經結束了,在向下就需要做一些假設才能向下推,甚至于有一些假設與原設定相悖,但是沒辦法,不假設就推不下去GG了
從理想假設 (Assumption1Assumption_1Assumption1?)
【完美】情況下,我們最終得到的XXX和X^\hat{X}X^應該是一模一樣的,所以我們假設重構的X^\hat{X}X^的分布和XXX一樣,所以
Loss(?,θ,X,h)=KL(P?(X,h)∣∣Pθ(X,h))=∑X,hP?(X)P?(h∣X)log?P?(X)P?(h∣X)Pθ(X)Pθ(h∣X)=∑X,hP?(X)P?(h∣X)log?P?(X)P?(h∣X)Pθ(X)Pθ(h∣X)=∑X,hP?(X)P?(h∣X)log?P?(X)Pθ(X)+∑X,hP?(X)P?(h∣X)log?P?(h∣X)Pθ(h∣X)=∑XP?(X)log?P?(X)Pθ(X)+∑XP?(X)∑hP?(h∣X)log?P?(h∣X)Pθ(h∣X)=KL(P?(X)∣∣Pθ(X))+KL(P?(h∣X)∣∣Pθ(h∣X))\begin{aligned} Loss(\phi, \theta,X,h) =& KL(P_{\phi}(X,h)||P_{\theta}(X,h)) \\ =& \sum_{X,h} P_\phi(X)P_{\phi}(h|X)\log\frac{P_\phi(X)P_{\phi}(h|X)}{P_{\theta}(X)P_{\theta}(h|X)}\\ =& \sum_{X,h} P_\phi(X)P_{\phi}(h|X)\log\frac{P_\phi(X)P_{\phi}(h|X)}{P_{\theta}(X)P_{\theta}(h|X)}\\ =& \sum_{X,h} P_\phi(X)P_{\phi}(h|X)\log\frac{P_\phi(X)}{P_{\theta}(X)} + \sum_{X,h} P_\phi(X)P_{\phi}(h|X)\log\frac{P_{\phi}(h|X)}{P_{\theta}(h|X)}\\ =& \sum_{X} P_\phi(X)\log\frac{P_\phi(X)}{P_{\theta}(X)} + \sum_{X} P_\phi(X)\sum_hP_{\phi}(h|X)\log\frac{P_{\phi}(h|X)}{P_{\theta}(h|X)}\\ =& KL(P_\phi(X)||P_{\theta}(X)) + KL(P_{\phi}(h|X)||P_{\theta}(h|X)) \end{aligned} Loss(?,θ,X,h)======?KL(P??(X,h)∣∣Pθ?(X,h))X,h∑?P??(X)P??(h∣X)logPθ?(X)Pθ?(h∣X)P??(X)P??(h∣X)?X,h∑?P??(X)P??(h∣X)logPθ?(X)Pθ?(h∣X)P??(X)P??(h∣X)?X,h∑?P??(X)P??(h∣X)logPθ?(X)P??(X)?+X,h∑?P??(X)P??(h∣X)logPθ?(h∣X)P??(h∣X)?X∑?P??(X)logPθ?(X)P??(X)?+X∑?P??(X)h∑?P??(h∣X)logPθ?(h∣X)P??(h∣X)?KL(P??(X)∣∣Pθ?(X))+KL(P??(h∣X)∣∣Pθ?(h∣X))?
這樣一來,聯合分布的KL距離就轉化為,邊緣分布的KL距離與條件分布的KL距離之和,那么換言之,也就是說loss變成了兩項:
化簡LossLossLoss
LossLossLoss的第一項是重構圖片X^\hat XX^和原圖片XXX的某種差距,我們可以化簡為
Loss1(?,θ,X,h)=KL(P?(X)∣∣Pθ(X))=∑XP?(X)log?P?(X)Pθ(X)=?Entropy(P?(X))?∑XP?(X)log?Pθ(X)=Constant?∑XP?(X)log?Pθ(X)=?∑XP?(X)∑hP?(h∣X)log?Pθ(X∣h)Pθ(h)Pθ(h∣X)Pθ(X)Pθ(X)+Constant=?∑XP?(X)∑hP?(h∣X)log?Pθ(X∣h)Pθ(h)Pθ(h∣X)P?(h∣X)P?(h∣X)+Constant=?∑XP?(X)∑hP?(h∣X)log?Pθ(X∣h)+∑XP?(X)∑hP?(h∣X)log?P?(h∣X)Pθ(h)?∑XP?(X)∑hP?(h∣X)log?P?(h∣X)Pθ(h∣X)+Constant=?∑XP?(X)∑hP?(h∣X)log?Pθ(X∣h)+∑XP?(X)∑hP?(h∣X)log?P?(h∣X)Pθ(h)?KL(P?(h∣X)∣∣Pθ(h∣X))+Constant=?∑XP?(X)∑hP?(h∣X)log?Pθ(X∣h)+KL(P?(h∣X)∣∣Pθ(h))?KL(P?(h∣X)∣∣Pθ(h∣X))+Constant\begin{aligned} Loss_1(\phi, \theta,X,h) =& KL(P_\phi(X)||P_{\theta}(X)) \\ =& \sum_X P_\phi(X) \log \frac{P_\phi(X)}{P_{\theta}(X)}\\ =& - Entropy(P_\phi(X)) - \sum_X P_\phi(X) \log P_{\theta}(X)\\ =& Constant - \sum_X P_\phi(X) \log P_{\theta}(X)\\ =& - \sum_X P_\phi(X) \sum_h P_\phi(h|X)\log \frac{P_{\theta}(X|h) P_{\theta}(h)}{P_{\theta}(h|X) P_{\theta}(X)}P_{\theta}(X)+ Constant\\ =& - \sum_X P_\phi(X) \sum_h P_\phi(h|X)\log \frac{P_{\theta}(X|h) P_{\theta}(h)}{P_{\theta}(h|X) }\frac{P_\phi(h|X)}{P_\phi(h|X)}+ Constant\\ =&- \sum_{X} P_\phi(X) \sum_h P_\phi(h|X)\log P_{\theta}(X|h) + \sum_{X} P_\phi(X) \sum_h P_\phi(h|X)\log \frac{P_\phi(h|X)}{P_{\theta}(h)} - \sum_{X} P_\phi(X)\sum_h P_\phi(h|X) \log\frac{P_\phi(h|X)}{P_{\theta}(h|X)}+ Constant\\ =&- \sum_{X} P_\phi(X) \sum_h P_\phi(h|X)\log P_{\theta}(X|h) + \sum_{X} P_\phi(X) \sum_h P_\phi(h|X)\log \frac{P_\phi(h|X)}{P_{\theta}(h)} - KL(P_{\phi}(h|X)||P_{\theta}(h|X))+ Constant\\ =& - \sum_{X} P_\phi(X) \sum_h P_\phi(h|X)\log P_{\theta}(X|h) + KL(P_\phi(h|X)||P_{\theta}(h)) - KL(P_{\phi}(h|X)||P_{\theta}(h|X))+ Constant \end{aligned} Loss1?(?,θ,X,h)=========?KL(P??(X)∣∣Pθ?(X))X∑?P??(X)logPθ?(X)P??(X)??Entropy(P??(X))?X∑?P??(X)logPθ?(X)Constant?X∑?P??(X)logPθ?(X)?X∑?P??(X)h∑?P??(h∣X)logPθ?(h∣X)Pθ?(X)Pθ?(X∣h)Pθ?(h)?Pθ?(X)+Constant?X∑?P??(X)h∑?P??(h∣X)logPθ?(h∣X)Pθ?(X∣h)Pθ?(h)?P??(h∣X)P??(h∣X)?+Constant?X∑?P??(X)h∑?P??(h∣X)logPθ?(X∣h)+X∑?P??(X)h∑?P??(h∣X)logPθ?(h)P??(h∣X)??X∑?P??(X)h∑?P??(h∣X)logPθ?(h∣X)P??(h∣X)?+Constant?X∑?P??(X)h∑?P??(h∣X)logPθ?(X∣h)+X∑?P??(X)h∑?P??(h∣X)logPθ?(h)P??(h∣X)??KL(P??(h∣X)∣∣Pθ?(h∣X))+Constant?X∑?P??(X)h∑?P??(h∣X)logPθ?(X∣h)+KL(P??(h∣X)∣∣Pθ?(h))?KL(P??(h∣X)∣∣Pθ?(h∣X))+Constant?
我們將Loss1(?,θ,X,h)Loss_1(\phi, \theta,X,h)Loss1?(?,θ,X,h)代回 Loss(?,θ,X,h)Loss(\phi, \theta,X,h)Loss(?,θ,X,h),將ConstantConstantConstant去掉(因為優化的時候常數其實沒有影響),得到
Loss(?,θ,X,h)=?∑XP?(X)∑hP?(h∣X)log?Pθ(X∣h)+KL(P?(h∣X)∣∣Pθ(h))\begin{aligned} Loss(\phi, \theta,X,h) =& - \sum_{X} P_\phi(X) \sum_h P_\phi(h|X)\log P_{\theta}(X|h) + KL(P_\phi(h|X)||P_{\theta}(h)) \end{aligned} Loss(?,θ,X,h)=??X∑?P??(X)h∑?P??(h∣X)logPθ?(X∣h)+KL(P??(h∣X)∣∣Pθ?(h))?
神奇么?兩個編碼器之間的某種差距,KL(P?(h∣X)∣∣Pθ(h∣X))KL(P_{\phi}(h|X)||P_{\theta}(h|X))KL(P??(h∣X)∣∣Pθ?(h∣X))消掉了。
好耶!到此為止,我們得到了和CS231n完全一致的結論!
但是具體應該怎么優化呢,我們就需要第二個假設了。
采樣優化(Assumption2Assumption_2Assumption2?)
讓我們重新回顧一下涉及到的參數:
讓我們再看一看優化目標
Loss(?,θ,X,h)=?∑XP?(X)∑hP?(h∣X)log?Pθ(X∣h)+KL(P?(h∣X)∣∣Pθ(h))\begin{aligned} Loss(\phi, \theta,X,h) =& - \sum_{X} P_\phi(X) \sum_h P_\phi(h|X)\log P_{\theta}(X|h) + KL(P_\phi(h|X)||P_{\theta}(h)) \end{aligned} Loss(?,θ,X,h)=??X∑?P??(X)h∑?P??(h∣X)logPθ?(X∣h)+KL(P??(h∣X)∣∣Pθ?(h))?
好家伙,辛辛苦苦得出的LossLossLoss里面有Pθ(X∣h)P_{\theta}(X|h)Pθ?(X∣h)和Pθ(h)P_{\theta}(h)Pθ?(h),這怎么搞呀?
難道要根據特定的?\phi?,取眾多的XXX,得到眾多的hhh,才算出這倆概率嘛?
很顯然,計算量巨大,而且每更新一次?\phi?,又得走一遍流程,你懂計算機的痛苦嘛?
所以,我們需要新的假設來簡化計算:
【完美】情況下,假設hhh服從正態分布,也即h~N(μ,Σ)h\sim N(\mu, \Sigma)h~N(μ,Σ)。這樣一來,我們的LossLossLoss就可以近似計算了。
- 注意,這樣的假設其實是與初始條件矛盾的。根據原定假設,hhh 是由 XXX 和 DecoderDecoderDecoder 確定的,換言之 hhh 的分布取決于 XXX 和 DecoderDecoderDecoder 的分布。但是這樣一來,為了求出 hhh 的分布我們需要大量采樣XXX,來求得Pθ(X∣h)P_{\theta}(X|h)Pθ?(X∣h),很難實現,所以就選擇將編碼 hhh 看作是高斯分布。
那么我們就可以算出LossLossLoss,然后用Back Propagation等方法優化啦~
二、樣例解釋
下面根據cs321n的PPT,解釋一下怎么優化
特別的,我們 LossLossLoss 里是對所有的 XXX 進行優化,用batch的話,可以記錄之前幾輪batch的loss,在新的一輪添加上動量項,來綜合考量進行優化
Loss(epochi+1)=γLoss(epochi)+(1?γ)Loss(thisepoch)Loss(epoch_{i+1}) = \gamma Loss(epoch_{i}) + (1-\gamma)Loss(this \; epoch) Loss(epochi+1?)=γLoss(epochi?)+(1?γ)Loss(thisepoch)
三、總結
用全新的角度梳理一遍VAE是真的難頂,推了幾次推不下去就存成草稿,后來不甘心,又打開草稿繼續推,終于用我自己覺得嚴謹的邏輯推完了VAE。有一說一,推完之后自己清楚了許多。
以前覺得VAE都是bug,弄了一堆不自恰的東西,現在看來,也就是兩個假設,剩下的夾著的都是純邏輯推導,就如同兩面包夾芝士(^ _ ^)。
這里特別感謝cs231n的PPT(后半段推導),上海交通大學張拳石老師的機器學習課程(前半段推導)。
總結
以上是生活随笔為你收集整理的从完美KL距离推导VAE的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: php一些错误的显示问题
- 下一篇: MAC 下shell工具推荐 zente