变分自编码器VAE:一步到位的聚类方案
作者丨蘇劍林
單位丨廣州火焰信息科技有限公司
研究方向丨NLP,神經(jīng)網(wǎng)絡(luò)
個人主頁丨kexue.fm
由于 VAE 中既有編碼器又有解碼器(生成器),同時隱變量分布又被近似編碼為標準正態(tài)分布,因此 VAE 既是一個生成模型,又是一個特征提取器。
在圖像領(lǐng)域中,由于 VAE 生成的圖片偏模糊,因此大家通常更關(guān)心 VAE 作為圖像特征提取器的作用。提取特征都是為了下一步的任務準備的,而下一步的任務可能有很多,比如分類、聚類等。本文來關(guān)心“聚類”這個任務。
一般來說,用 AE 或者 VAE 做聚類都是分步來進行的,即先訓練一個普通的 VAE,然后得到原始數(shù)據(jù)的隱變量,接著對隱變量做一個 K-Means 或 GMM 之類的。但是這樣的思路的整體感顯然不夠,而且聚類方法的選擇也讓我們糾結(jié)。
本文介紹基于 VAE 的一個“一步到位”聚類思路,它同時允許我們完成無監(jiān)督地完成聚類和條件生成。
理論
一般框架
回顧 VAE 的 loss(如果沒印象請參考再談變分自編碼器VAE:從貝葉斯觀點出發(fā)):
通常來說,我們會假設(shè) q(z) 是標準正態(tài)分布,p(z|x),q(x|z) 是條件正態(tài)分布,然后代入計算,就得到了普通的 VAE 的 loss。
然而,也沒有誰規(guī)定隱變量一定是連續(xù)變量吧?這里我們就將隱變量定為 (z,y),其中 z 是一個連續(xù)變量,代表編碼向量;y 是離散的變量,代表類別。直接把 (1) 中的 z 替換為 (z,y),就得到:
這就是用來做聚類的 VAE 的 loss 了。
分步假設(shè)
啥?就完事了?呃,是的,如果只考慮一般化的框架,(2) 確實就完事了。?
不過落實到實踐中,(2) 可以有很多不同的實踐方案,這里介紹比較簡單的一種。首先我們要明確,在 (2 )中,我們只知道 p?(x)(通過一批數(shù)據(jù)給出的經(jīng)驗分布),其他都是沒有明確下來的。于是為了求解 (2),我們需要設(shè)定一些形式。一種選取方案為:
代入 (2) 得到:
其實 (4) 式還是相當直觀的,它分布描述了編碼和生成過程:
1. 從原始數(shù)據(jù)中采樣到 x,然后通過 p(z|x) 可以得到編碼特征 z,然后通過分類器 p(y|z) 對編碼特征進行分類,從而得到類別;
2. 從分布 q(y) 中選取一個類別 y,然后從分布 q(z|y) 中選取一個隨機隱變量 z,再通過生成器 q(x|z) 解碼為原始樣本。
具體模型
(4) 式其實已經(jīng)很具體了,我們只需要沿用以往 VAE 的做法:p(z|x) 一般假設(shè)為均值為 μ(x) 方差為的正態(tài)分布,q(x|z) 一般假設(shè)為均值為 G(z) 方差為常數(shù)的正態(tài)分布(等價于用 MSE 作為 loss),q(z|y) 可以假設(shè)為均值為 μy 方差為 1 的正態(tài)分布,至于剩下的 q(y),p(y|z),q(y) 可以假設(shè)為均勻分布(它就是個常數(shù)),也就是希望每個類大致均衡,而 p(y|z) 是對隱變量的分類器,隨便用個 softmax 的網(wǎng)絡(luò)就可以擬合了。?
最后,可以形象地將 (4) 改寫為:
其中 z~p(z|x) 是重參數(shù)操作,而方括號中的三項 loss,各有各的含義:
1. ?log q(x|z) 希望重構(gòu)誤差越小越好,也就是 z 盡量保留完整的信息;
2.希望 z 能盡量對齊某個類別的“專屬”的正態(tài)分布,就是這一步起到聚類的作用;
3. KL(p(y|z)‖q(y)) 希望每個類的分布盡量均衡,不會發(fā)生兩個幾乎重合的情況(坍縮為一個類)。當然,有時候可能不需要這個先驗要求,那就可以去掉這一項。
實驗
實驗代碼自然是 Keras 完成的了,在 MNIST 和 Fashion-MNIST 上做了實驗,表現(xiàn)都還可以。實驗環(huán)境:Keras 2.2 + TensorFlow 1.8 + Python 2.7。
代碼實現(xiàn)
代碼位于:
https://github.com/bojone/vae/blob/master/vae_keras_cluster.py?
其實注釋應該比較清楚了,而且相比普通的 VAE 改動不大。可能稍微有難度的是這個怎么實現(xiàn)。因為 y 是離散的,所以事實上這就是一個矩陣乘法(相乘然后對某個公共變量求和,就是矩陣乘法的一般形式),用 K.batch_dot 實現(xiàn)。?
其他的話,讀者應該先弄清楚普通的 VAE 實現(xiàn)過程,然后再看本文的內(nèi)容和代碼,不然估計是一臉懵的。
MNIST
這里是 MNIST?的實驗結(jié)果圖示,包括類內(nèi)樣本圖示和按類采樣圖示。最后還簡單估算了一下,以每一類對應的數(shù)目最多的那個真實標簽為類標簽的話,最終的 test 準確率大約有 84.5%,對比這篇文章 Unsupervised Deep Embedding for Clustering Analysis [1]?的結(jié)果(最高也是 84% 左右),感覺應該很不錯了。?
聚類圖示
按類采樣
Fashion-MNIST
這里是 Fashion-MNIST [2]?的實驗結(jié)果圖示,包括類內(nèi)樣本圖示和按類采樣圖示,最終的 test 準確率大約有 60.6%。?
聚類圖示
按類采樣
總結(jié)
文章簡單地實現(xiàn)了一下基于 VAE 的聚類算法,算法的特點就是一步到位,結(jié)合“編碼”、“聚類”和“生成”三個任務同時完成,思想是對 VAE 的 loss 的一般化。
感覺還有一定的提升空間,比如式 (4) 只是式 (2) 的一個例子,還可以考慮更加一般的情況。代碼中的 encoder 和 decoder 也都沒有經(jīng)過仔細調(diào)優(yōu),僅僅是驗證想法所用。
參考文獻
[1].?Unsupervised Deep Embedding for Clustering Analysis Junyuan Xie, Ross Girshick, and Ali Farhadi in International Conference on Machine Learning (ICML), 2016.
[2].?https://github.com/zalandoresearch/fashion-mnist
點擊以下標題查看更多相關(guān)文章:?
變分自編碼器VAE:原來是這么一回事 | 附開源代碼
再談變分自編碼器VAE:從貝葉斯觀點出發(fā)
變分自編碼器VAE:這樣做為什么能成?
漫談概率 PCA 和變分自編碼器
全新視角:用變分推斷統(tǒng)一理解生成模型
PaperWeekly 第二十七期 | VAE for NLP
#投 稿 通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢??答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學習心得或技術(shù)干貨。我們的目的只有一個,讓知識真正流動起來。
??來稿標準:
? 稿件確系個人原創(chuàng)作品,來稿需注明作者個人信息(姓名+學校/工作單位+學歷/職位+研究方向)?
? 如果文章并非首發(fā),請在投稿時提醒并附上所有已發(fā)布鏈接?
? PaperWeekly 默認每篇文章都是首發(fā),均會添加“原創(chuàng)”標志
? 投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請單獨在附件中發(fā)送?
? 請留下即時聯(lián)系方式(微信或手機),以便我們在編輯發(fā)布時和作者溝通
?
現(xiàn)在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關(guān)注」訂閱我們的專欄吧
關(guān)于PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報道人工智能前沿論文成果的學術(shù)平臺。如果你研究或從事 AI 領(lǐng)域,歡迎在公眾號后臺點擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
▽ 點擊 |?閱讀原文?| 查看作者博客
總結(jié)
以上是生活随笔為你收集整理的变分自编码器VAE:一步到位的聚类方案的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: ACL 2018论文解读 | 基于路径的
- 下一篇: 近期大热的AutoML领域,都有哪些值得