线性Attention的探索:Attention必须有个Softmax吗?
?PaperWeekly 原創(chuàng) ·?作者|蘇劍林
單位|追一科技
研究方向|NLP、神經(jīng)網(wǎng)絡(luò)
眾所周知,盡管基于 Attention 機(jī)制的 Transformer 類模型有著良好的并行性能,但它的空間和時(shí)間復(fù)雜度都是 級(jí)別的,n 是序列長(zhǎng)度,所以當(dāng) n 比較大時(shí) Transformer 模型的計(jì)算量難以承受。
近來(lái),也有不少工作致力于降低 Transformer 模型的計(jì)算量,比如模型剪枝、量化、蒸餾等精簡(jiǎn)技術(shù),又或者修改 Attention 結(jié)構(gòu),使得其復(fù)雜度能降低到 甚至 。
前幾天筆者讀到了論文 Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention?[1]?,了解到了線性化?Attention (Linear Attention)這個(gè)探索點(diǎn),繼而閱讀了一些相關(guān)文獻(xiàn),有一些不錯(cuò)的收獲,最后將自己對(duì)線性化 Attention 的理解匯總在此文中。
Attention
當(dāng)前最流行的 Attention 機(jī)制當(dāng)屬 Scaled-Dot Attention [2] ,形式為:
這里的 ,簡(jiǎn)單起見我們就沒(méi)顯式地寫出 Attention 的縮放因子了。
本文我們主要關(guān)心 Self Attention 場(chǎng)景,所以為了介紹上的方便統(tǒng)一設(shè) ,一般長(zhǎng)序列場(chǎng)景下都有 (BERT base 里邊 d=64)。
相關(guān)解讀可以參考筆者的一文讀懂「Attention is All You Need」| 附代碼實(shí)現(xiàn),以及它的一些改進(jìn)工作也可以參考突破瓶頸,打造更強(qiáng)大的 Transformer [3]、Google 新作 Synthesizer:我們還不夠了解自注意力,這里就不多深入介紹了。
1.1 摘掉Softmax
讀者也許想不到,制約 Attention 性能的關(guān)鍵因素,其實(shí)是定義里邊的 Softmax!事實(shí)上,簡(jiǎn)單地推導(dǎo)一下就可以得到這個(gè)結(jié)論。
這一步我們得到一個(gè) 的矩陣,就是這一步?jīng)Q定了 Attention 的復(fù)雜度是 ;如果沒(méi)有 Softmax,那么就是三個(gè)矩陣連乘 ,而矩陣乘法是滿足結(jié)合率的,所以我們可以先算 ,得到一個(gè) 的矩陣,然后再用 左乘它,由于 ,所以這樣算大致的復(fù)雜度只是 (就是 左乘那一步占主導(dǎo))。
也就是說(shuō),去掉 Softmax 的 Attention 的復(fù)雜度可以降到最理想的線性級(jí)別 !這顯然就是我們的終極追求:Linear Attention,復(fù)雜度為線性級(jí)別的 Attention。所以,本文的主題就是探究摘掉 Softmax 后的線形 Attention。
1.2 一般的定義
問(wèn)題是,直接去掉 Softmax 還能算是 Attention 嗎?它還能有標(biāo)準(zhǔn)的 Attention 的效果嗎?為了回答這個(gè)問(wèn)題,我們先將 Scaled-Dot Attention 的定義(1)等價(jià)地改寫為(本文的向量都是列向量)。
所以,Scaled-Dot Attention 其實(shí)就是以 為權(quán)重對(duì) 做加權(quán)平均。所以我們可以提出一個(gè) Attention 的一般化定義:
也就是把 換成 的一般函數(shù) ,為了保留 Attention 的相似特性,我們要求 恒成立。也就是說(shuō),我們?nèi)绻x新式的 Attention,那么要保留式(3)的形式,并且滿足 。
這種一般形式的 Attention 在 CV 中也被稱為 Non-Local 網(wǎng)絡(luò),來(lái)自文章 Non-local Neural Networks [4]。
幾個(gè)例子
如果直接去掉 Softmax,那么就是 ,問(wèn)題是內(nèi)積無(wú)法保證非負(fù)性,所以這還不是一個(gè)合理的選擇。下面我們簡(jiǎn)單介紹幾種可取的方案。
值得指出的是,下面介紹的這幾種 Linear Attention,前兩種只做了 CV 的實(shí)驗(yàn),第三種是筆者自己構(gòu)思的,所以都還沒(méi)有在 NLP 任務(wù)上做過(guò)什么實(shí)驗(yàn),各位做模型改進(jìn)的 NLPer 們就有實(shí)驗(yàn)方向了。
2.1 核函數(shù)形式
一個(gè)自然的想法是:如果 的每個(gè)元素都是非負(fù)的,那么內(nèi)積自然也就是非負(fù)的。為了完成這點(diǎn),我們可以給 各自加個(gè)激活函數(shù) ,即:
其中 是值域非負(fù)的激活函數(shù)。本文開頭提到的論文 Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention [5]?選擇的是 。
非要講故事的話,式(4)可以聯(lián)想到“核方法(kernal method)”,尤其是 時(shí) 就相當(dāng)于一個(gè)核函數(shù),而 就是通過(guò)核函數(shù)所定義的內(nèi)積。
這方面的思考可以參考論文 Transformer disp: An unified understanding for transformer’s attention via the lens of kernel [6],此處不做過(guò)多延伸。
2.2 妙用Softmax
另一篇更早的文章 Efficient Attention: Attention with Linear Complexities?[7]?則給出了一個(gè)更有意思的選擇。它留意到在 中,,如果“ 在 d 那一維是歸一化的、并且 在 n 那一維是歸一化的”,那么 就是自動(dòng)滿足歸一化了,所以它給出的選擇是:
其中 、 分別指在第一個(gè)(n)、第二個(gè)維度(d)進(jìn)行 Softmax 運(yùn)算。也就是說(shuō),這時(shí)候我們是各自給 加 Softmax,而不是 算完之后才加 Softmax。
其實(shí)可以證明這個(gè)形式也是式(4)的一個(gè)特例,此時(shí)對(duì)應(yīng)于 ,讀者可以自行推導(dǎo)一下。
2.3 自己的構(gòu)思
在這里,筆者給出自己的一種構(gòu)思。這個(gè)構(gòu)思的出發(fā)點(diǎn)不再是式(4),而是源于我們對(duì)原始定義(2)的近似。由泰勒展開我們有:
如果 ,那么就可以保證右端的非負(fù)性,而從可以讓 。到這里讀者可能已經(jīng)想到了,想要保證 ,只需要分別對(duì) 做 歸一化。所以,筆者最終提出的方案就是:
這不同于形式(4),但理論上它應(yīng)該是最接近原始的 Scaled-Dot Attention 了。
相關(guān)工作
通過(guò)修改 Attention 的形式來(lái)降低它的計(jì)算復(fù)雜度,相關(guān)的工作有很多,這里簡(jiǎn)要列舉一些。
3.1 稀疏Attention
我們之前介紹過(guò) OpenAI 的 Sparse Attention,通過(guò)“只保留小區(qū)域內(nèi)的數(shù)值、強(qiáng)制讓大部分注意力為零”的方式,來(lái)減少 Attention 的計(jì)算量。經(jīng)過(guò)特殊設(shè)計(jì)之后,Attention 矩陣的非 0 元素只有 個(gè),所以理論上它也是一種線性級(jí)別的 Attention。類似的工作還有 Longformer。
但是很明顯,這種思路有兩個(gè)不足之處:
如何選擇要保留的注意力區(qū)域,這是人工主觀決定的,帶有很大的不智能性;
它需要從編程上進(jìn)行特定的設(shè)計(jì)優(yōu)化,才能得到一個(gè)高效的實(shí)現(xiàn),所以它不容易推廣。
3.2 Reformer
Reformer 也是有代表性的改進(jìn)工作,它將 Attention 的復(fù)雜度降到了 。
某種意義上來(lái)說(shuō),Reformer 也是稀疏 Attention 的一種,只不過(guò)它的稀疏 pattern 不是事先指定的,而是通過(guò) LSH(Locality Sensitive Hashing)技術(shù)(近似地)快速地找到最大的若干個(gè) Attention 值,然后只去計(jì)算那若干個(gè)值。
此外,Reformer 通過(guò)構(gòu)造可逆形式的 FFN(Feedforward Network)替換掉原來(lái)的 FFN,然后重新設(shè)計(jì)反向傳播過(guò)程,從而降低了顯存占用量。
所以,相比前述稀疏 Attention,Reformer 解決了它的第一個(gè)缺點(diǎn),但是依然有第二個(gè)缺點(diǎn):實(shí)現(xiàn)起來(lái)復(fù)雜度高。要實(shí)現(xiàn) LSH 形式的 Attention 比標(biāo)準(zhǔn)的 Attention 復(fù)雜多了,對(duì)可逆網(wǎng)絡(luò)重寫反向傳播過(guò)程對(duì)普通讀者來(lái)說(shuō)更是遙不可及。
3.3 Linformer
跟本文所介紹的 Linear Attention 很相似的一個(gè)工作是 Facebook 最近放出來(lái)的 Linformer,它依然保留原始的 Scaled-Dot Attention 形式,但在進(jìn)行 Attention 之前,用兩個(gè) 的矩陣 分別對(duì) 進(jìn)行投影,即變?yōu)?#xff1a;
這樣一來(lái), 就只是一個(gè) 的矩陣,而作者聲稱對(duì)于哪怕對(duì)于很大的序列長(zhǎng)度 n,m 也可以保持為一個(gè)適中的常數(shù),從而這種 Attention 也是線性的。
但是,筆者認(rèn)為“對(duì)于超長(zhǎng)序列 m 可以保持不變”這個(gè)結(jié)論是值得質(zhì)疑的,原論文中對(duì)于長(zhǎng)序列作者只做了 MLM 任務(wù),而很明顯 MLM 并不那么需要長(zhǎng)程依賴,所以這個(gè)實(shí)驗(yàn)沒(méi)什么說(shuō)服力。因此,Linformer 是不是真的 Linear,還有待商榷。
3.4 自回歸生成
Linformer 的另一個(gè)缺點(diǎn)是 這兩變直接把整個(gè)序列的信息給“糅合”起來(lái)了,所以它沒(méi)法簡(jiǎn)單地把將來(lái)信息給 Mask 掉(Causal Masking),從而無(wú)法做語(yǔ)言模型、Seq2Seq 等自回歸生成任務(wù),這也是剛才說(shuō)的原作者只做了 MLM 任務(wù)的原因。
相比之下,本文介紹的幾種 Linear Attention 都能做到這一點(diǎn)。以式(3)和式(4)為例,如果要 Mask 掉未來(lái)信息,那么只需要把求和 改為 :
實(shí)現(xiàn)上式有兩種方式:第一方式是設(shè) 以及 ,我們有:
這說(shuō)明這種 Attention 可以作為一個(gè) RNN 模型用遞歸的方式實(shí)現(xiàn),它的空間復(fù)雜度最低,但是要串性計(jì)算,適合預(yù)測(cè)解碼時(shí)使用;第二種是直接將 做外積,得到一個(gè) 的矩陣,然后對(duì) n 那一維執(zhí)行 運(yùn)算,這樣就一次性得到 了,它的速度最快,但空間占用最大,適合訓(xùn)練時(shí)使用。
3.5 下采樣技術(shù)
從結(jié)果上來(lái)看,Linformer 的 就是將序列變短(下采樣)了,而將序列變短的一個(gè)最樸素的方法就是 Pooling 了,所以筆者之前也嘗試過(guò)把 Pooling 技術(shù)引入到 Transformer 中去。
近來(lái)也有類似的工作發(fā)出來(lái),比如IBM的PoWER-BERT: Accelerating BERT Inference via Progressive Word-vector Elimination [8] 和 Google 的 Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing [9] 。
除了 Pooling 之外,其實(shí)還有其他的下采樣技術(shù),比如可以通過(guò) stride > 1 的一維卷積來(lái)實(shí)現(xiàn),基于這個(gè)思路,或許我們可以把 FFN 里邊的 Position-Wise 全連接換成 stride > 1 的一維卷積?總之這方面應(yīng)該也能玩出很多花樣來(lái),不過(guò)跟 Linformer 一樣,這樣糅合之后做自回歸生成就很難了。
文章小結(jié)
本文介紹了一些從結(jié)構(gòu)上對(duì) Attention 進(jìn)行修改從而降低其計(jì)算復(fù)雜度的工作,其中最主要的 idea 是去掉標(biāo)準(zhǔn) Attention 中的 Softmax,就可以使得 Attention 的復(fù)雜度退化為理想的 級(jí)別(Linear Attention)。
相比于其他類似的改進(jìn)結(jié)構(gòu)的工作,這種修改能在把復(fù)雜度降到 的同時(shí),依然保留所有的 “token-token” 的注意力,同時(shí)還能保留用于做自回歸生成的可能性。
參考鏈接
[1] https://arxiv.org/abs/2006.16236
[2] https://arxiv.org/abs/1706.03762
[3] https://kexue.fm/archives/7325
[4] https://kexue.fm/archives/1711.07971
[5] https://arxiv.org/abs/2006.16236
[6] https://arxiv.org/abs/1908.11775
[7] https://arxiv.org/abs/1812.01243
[8] https://arxiv.org/abs/2001.08950
[9] https://arxiv.org/abs/2006.03236
更多閱讀
#投 稿?通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達(dá)讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢?答案就是:你不認(rèn)識(shí)的人。
總有一些你不認(rèn)識(shí)的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學(xué)者和學(xué)術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵(lì)高校實(shí)驗(yàn)室或個(gè)人,在我們的平臺(tái)上分享各類優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學(xué)習(xí)心得或技術(shù)干貨。我們的目的只有一個(gè),讓知識(shí)真正流動(dòng)起來(lái)。
?????來(lái)稿標(biāo)準(zhǔn):
? 稿件確系個(gè)人原創(chuàng)作品,來(lái)稿需注明作者個(gè)人信息(姓名+學(xué)校/工作單位+學(xué)歷/職位+研究方向)?
? 如果文章并非首發(fā),請(qǐng)?jiān)谕陡鍟r(shí)提醒并附上所有已發(fā)布鏈接?
? PaperWeekly 默認(rèn)每篇文章都是首發(fā),均會(huì)添加“原創(chuàng)”標(biāo)志
?????投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請(qǐng)單獨(dú)在附件中發(fā)送?
? 請(qǐng)留下即時(shí)聯(lián)系方式(微信或手機(jī)),以便我們?cè)诰庉嫲l(fā)布時(shí)和作者溝通
????
現(xiàn)在,在「知乎」也能找到我們了
進(jìn)入知乎首頁(yè)搜索「PaperWeekly」
點(diǎn)擊「關(guān)注」訂閱我們的專欄吧
關(guān)于PaperWeekly
PaperWeekly 是一個(gè)推薦、解讀、討論、報(bào)道人工智能前沿論文成果的學(xué)術(shù)平臺(tái)。如果你研究或從事 AI 領(lǐng)域,歡迎在公眾號(hào)后臺(tái)點(diǎn)擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
總結(jié)
以上是生活随笔為你收集整理的线性Attention的探索:Attention必须有个Softmax吗?的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 10 月 1~7 日全社会跨区域人员流动
- 下一篇: 直播 | 旷视研究院最新理论成果:批归一