用Transformer完全代替CNN:AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE
原文地址:https://zhuanlan.zhihu.com/p/266311690
論文地址:https://arxiv.org/pdf/2010.11929.pdf
代碼地址:https://github.com/google-research/vision_transformer
用Transformer完全代替CNN
- 1. Story
- 2. Model
- a 將圖像轉(zhuǎn)化為序列化數(shù)據(jù)
- b Position embedding
- c Learnable embedding
- d Transformer encoder
- 3. 混合結(jié)構(gòu)
- 4. Fine-tuning過(guò)程中高分辨率圖像的處理
- 5. 實(shí)驗(yàn)
1. Story
近年來(lái),Transformer已經(jīng)成了NLP領(lǐng)域的標(biāo)準(zhǔn)配置,但是CV領(lǐng)域還是CNN(如ResNet, DenseNet等)占據(jù)了絕大多數(shù)的SOTA結(jié)果。
最近CV界也有很多文章將transformer遷移到CV領(lǐng)域,這些文章總的來(lái)說(shuō)可以分為兩個(gè)大類:
- 將self-attention機(jī)制與常見(jiàn)的CNN架構(gòu)結(jié)合;
- 用self-attention機(jī)制完全替代CNN。
本文采用的也是第2種思路。雖然已經(jīng)有很多工作用self-attention完全替代CNN,且在理論上效率比較高,但是它們用了特殊的attention機(jī)制,無(wú)法從硬件層面加速,所以目前CV領(lǐng)域的SOTA結(jié)果還是被CNN架構(gòu)所占據(jù)。
文章不同于以往工作的地方,就是盡可能地將NLP領(lǐng)域的transformer不作修改地搬到CV領(lǐng)域來(lái)。但是NLP處理的語(yǔ)言數(shù)據(jù)是序列化的,而CV中處理的圖像數(shù)據(jù)是三維的(長(zhǎng)、寬和channels)。
所以我們需要一個(gè)方式將圖像這種三維數(shù)據(jù)轉(zhuǎn)化為序列化的數(shù)據(jù)。文章中,圖像被切割成一個(gè)個(gè)patch,這些patch按照一定的順序排列,就成了序列化的數(shù)據(jù)。(具體將在下面講述)
在實(shí)驗(yàn)中,作者發(fā)現(xiàn),在中等規(guī)模的數(shù)據(jù)集上(例如ImageNet),transformer模型的表現(xiàn)不如ResNets;而當(dāng)數(shù)據(jù)集的規(guī)模擴(kuò)大,transformer模型的效果接近或者超過(guò)了目前的一些SOTA結(jié)果。作者認(rèn)為是大規(guī)模的訓(xùn)練可以鼓勵(lì)transformer學(xué)到CNN結(jié)構(gòu)所擁有的translation equivariance和locality.
2. Model
Vision Transformer (ViT)結(jié)構(gòu)示意圖
模型的結(jié)構(gòu)其實(shí)比較簡(jiǎn)單,可以分成以下幾個(gè)部分來(lái)理解:
a 將圖像轉(zhuǎn)化為序列化數(shù)據(jù)
作者采用了了一個(gè)比較簡(jiǎn)單的方式。如下圖所示。首先將圖像分割成一個(gè)個(gè)patch,然后將每個(gè)patch reshape成一個(gè)向量,得到所謂的flattened patch。
具體地,如果圖片是H×W×CH\times W\times CH×W×C維,用P×PP\times PP×P大小的patch去分割圖片可以得到N個(gè)patch,那么每個(gè)patch的shape就是P×P×CP\times P\times CP×P×C,轉(zhuǎn)化為向量就是P2CP^2CP2C維向量,將N個(gè)patch reshape后的向量concat在一起就得到了一個(gè)N×(P2C)N\times (P^2C)N×(P2C)的二維矩陣,相當(dāng)于NLP中輸入transformer的詞向量。
- 分割圖像得到patch
從上面的過(guò)程可以看出,當(dāng)patch的大小變化時(shí)(即 P 變化時(shí)),每個(gè)patch reshape后得到的 P2CP^2CP2C 維向量的長(zhǎng)度也會(huì)變化。為了避免模型結(jié)構(gòu)受到patch size的影響,作者對(duì)上述過(guò)程得到的flattened patches向量做了Linear Projection(如下圖所示),將不同長(zhǎng)度的flattened patch向量轉(zhuǎn)化為固定長(zhǎng)度的向量(記做D維向量)。
- 對(duì)flattened patches做linear projection
綜上,原本H×W×CH\times W\times CH×W×C維的圖片被轉(zhuǎn)化為N個(gè)D維的向量(或者一個(gè)N×DN\times DN×D維的二維矩陣)。
b Position embedding
- Position embedding
由于transformer模型本身是沒(méi)有位置信息的,和NLP中一樣,我們需要用position embedding將位置信息加到模型中去。
如上圖所示1,編號(hào)有0-9的紫色框表示各個(gè)位置的position embedding,而紫色框旁邊的粉色框則是經(jīng)過(guò)linear projection之后的flattened patch向量。文中采用將position embedding(即圖中紫色框)和patch embedding(即圖中粉色框)相加的方式結(jié)合position信息。
c Learnable embedding
如果大家仔細(xì)看上圖,就會(huì)發(fā)現(xiàn)帶星號(hào)的粉色框(即0號(hào)紫色框右邊的那個(gè))不是通過(guò)某個(gè)patch產(chǎn)生的。這個(gè)是一個(gè)learnable embedding(記作 XclassX_{class}Xclass? ),其作用類似于BERT中的[class] token。在BERT中,[class] token經(jīng)過(guò)encoder后對(duì)應(yīng)的結(jié)果作為整個(gè)句子的表示;類似地,這里 XclassX_{class}Xclass? 經(jīng)過(guò)encoder后對(duì)應(yīng)的結(jié)果也作為整個(gè)圖的表示。
至于為什么BERT或者這篇文章的ViT要多加一個(gè)token呢?因?yàn)槿绻藶榈刂付ㄒ粋€(gè)embedding(例如本文中某個(gè)patch經(jīng)過(guò)Linear Projection得到的embedding)經(jīng)過(guò)encoder得到的結(jié)果作為整體的表示,則不可避免地會(huì)使得整體表示偏向于這個(gè)指定embedding的信息(例如圖像的表示偏重于反映某個(gè)patch的信息)。而這個(gè)新增的token沒(méi)有語(yǔ)義信息(即在句子中與任何的詞無(wú)關(guān),在圖像中與任何的patch無(wú)關(guān)),所以不會(huì)造成上述問(wèn)題,能夠比較公允地反映全圖的信息。
d Transformer encoder
Transformer Encoder結(jié)構(gòu)和NLP中transformer結(jié)構(gòu)基本上相同,所以這里只給出其結(jié)構(gòu)圖,和公式化的計(jì)算過(guò)程,也是順便用公式表達(dá)了之前所說(shuō)的幾個(gè)部分內(nèi)容。
Transformer Encoder的結(jié)構(gòu)如下圖所示:
對(duì)于Encoder的第 lll 層,記其輸入為zl?1z_{l-1}zl?1?,輸出為zlz_lzl?,則計(jì)算過(guò)程為:
其中MSA為Multi-Head Self-Attention(即Transformer Encoder結(jié)構(gòu)圖中的綠色框),MLP為Multi-Layer Perceptron(即Transformer Encoder結(jié)構(gòu)圖中的藍(lán)色框),LN為L(zhǎng)ayer Norm(即Transformer Encoder結(jié)構(gòu)圖中的黃色框)。
Encoder第一層的輸入z0z_0z0?是通過(guò)下面的公式得到的:
其中Xp1,...,XpNX_p^1,...,X_p^NXp1?,...,XpN?即未Linear Projection后的patch embedding(都是p2Cp^2Cp2C維)
3. 混合結(jié)構(gòu)
文中還提出了一個(gè)比較有趣的解決方案,將transformer和CNN結(jié)合,即將ResNet的中間層的feature map作為transformer的輸入。
和之前所說(shuō)的將圖片分成patch然后reshape成sequence不同的是,在這種方案中,作者直接將ResNet某一層的feature map reshape成sequence,再通過(guò)Linear Projection變?yōu)門ransformer輸入的維度,然后直接輸入進(jìn)Transformer中。
4. Fine-tuning過(guò)程中高分辨率圖像的處理
在Fine-tuning到下游任務(wù)時(shí),當(dāng)圖像的分辨率增大時(shí)(即圖像的長(zhǎng)和寬增大時(shí)),如果保持patch大小不變,得到的patch個(gè)數(shù)將增加(記分辨率增大后新的patch個(gè)數(shù)為 N′N^{'}N′ )。但是由于在pretrain時(shí),position embedding的個(gè)數(shù)和pretrain時(shí)分割得到的patch個(gè)數(shù)(即上文中的 N )相同。則多出來(lái)的 N′?NN^{'}-NN′?N 個(gè)positioin embedding在pretrain中是未定義或者無(wú)意義的。
為了解決這個(gè)問(wèn)題,文章中提出用2D插值的方法,基于原圖中的位置信息,將pretrain中的 N 個(gè)position embedding插值成N′N^{'}N′ 個(gè)。這樣在得到 N′N^{'}N′ 個(gè)position embedding的同時(shí)也保證了position embedding的語(yǔ)義信息。
5. 實(shí)驗(yàn)
實(shí)驗(yàn)部分由于涉及到的細(xì)節(jié)較多就不具體介紹了,大家如果感興趣可以參看原文。(不得不說(shuō)Google的實(shí)驗(yàn)?zāi)芰外n能力不是一般人能比的…)
主要的實(shí)驗(yàn)結(jié)論在story中就已經(jīng)介紹過(guò)了,這里復(fù)制粘貼一下:在中等規(guī)模的數(shù)據(jù)集上(例如ImageNet),transformer模型的表現(xiàn)不如ResNets;而當(dāng)數(shù)據(jù)集的規(guī)模擴(kuò)大,transformer模型的效果接近或者超過(guò)了目前的一些SOTA結(jié)果。
比較有趣的是,作者還做了很多其他的分析來(lái)解釋transfomer的合理性。大家如果感興趣也可以參看原文,這里放幾張文章中的圖。
總結(jié)
以上是生活随笔為你收集整理的用Transformer完全代替CNN:AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 【腾讯面试题】Nginx
- 下一篇: leetcode303 Range Su