PyTorch必备神器 | 唯快不破:基于Apex的混合精度加速
作者丨Nicolas
單位丨追一科技AI Lab研究員
研究方向丨信息抽取、機(jī)器閱讀理解
你想獲得雙倍訓(xùn)練速度的快感嗎??
你想讓你的顯卡內(nèi)存瞬間翻倍嗎??
如果告訴你只需要三行代碼即可實(shí)現(xiàn),你信不??
在這篇文章里,筆者會(huì)詳解一下混合精度計(jì)算(Mixed Precision),并介紹一款 NVIDIA 開發(fā)的基于 PyTorch 的混合精度訓(xùn)練加速神器——Apex,最近 Apex 更新了 API,可以用短短三行代碼就能實(shí)現(xiàn)不同程度的混合精度加速,訓(xùn)練時(shí)間直接縮小一半。?
話不多說,直接先教你怎么用。
PyTorch實(shí)現(xiàn)
from?apex?import?amp model,?optimizer?=?amp.initialize(model,?optimizer,?opt_level="O1")?#?這里是“歐一”,不是“零一” with?amp.scale_loss(loss,?optimizer)?as?scaled_loss: scaled_loss.backward()對,就是這么簡單,如果你不愿意花時(shí)間深入了解,讀到這基本就可以直接使用起來了。
但是如果你希望對 FP16 和 Apex 有更深入的了解,或是在使用中遇到了各種不明所以的“Nan”的同學(xué),可以接著讀下去,后面會(huì)有一些有趣的理論知識和筆者最近一個(gè)月使用 Apex 遇到的各種 bug,不過當(dāng)你深入理解并解決掉這些 bug 后,你就可以徹底擺脫“慢吞吞”的 FP32 啦。
理論部分
為了充分理解混合精度的原理,以及?API?的使用,先補(bǔ)充一點(diǎn)基礎(chǔ)的理論知識。
1.?什么是FP16?
半精度浮點(diǎn)數(shù)是一種計(jì)算機(jī)使用的二進(jìn)制浮點(diǎn)數(shù)數(shù)據(jù)類型,使用?2?字節(jié)(16?位)存儲(chǔ)。
▲?FP16和FP32表示的范圍和精度對比
?
其中,?sign?位表示正負(fù),?exponent?位表示指數(shù),?fraction?位表示的是分?jǐn)?shù)。其中當(dāng)指數(shù)為零的時(shí)候,下圖加號左邊為 0,其他情況為?1。
▲?FP16的表示范例
?
2.?為什么需要FP16?
在使用?FP16?之前,我想再贅述一下為什么我們使用?FP16。
減少顯存占用?現(xiàn)在模型越來越大,當(dāng)你使用?Bert?這一類的預(yù)訓(xùn)練模型時(shí),往往模型及模型計(jì)算就占去顯存的大半,當(dāng)想要使用更大的?Batch Size?的時(shí)候會(huì)顯得捉襟見肘。由于 FP16?的內(nèi)存占用只有?FP32?的一半,自然地就可以幫助訓(xùn)練過程節(jié)省一半的顯存空間。
加快訓(xùn)練和推斷的計(jì)算?與普通的空間時(shí)間?Trade-off?的加速方法不同,FP16?除了能節(jié)約內(nèi)存,還能同時(shí)節(jié)省模型的訓(xùn)練時(shí)間。在大部分的測試中,基于 FP16?的加速方法能夠給模型訓(xùn)練帶來多一倍的加速體驗(yàn)(爽感類似于兩倍速看肥皂劇)。
張量核心的普及?硬件的發(fā)展同樣也推動(dòng)著模型計(jì)算的加速,隨著?NVIDIA?張量核心(Tensor Core)的普及,16bit?計(jì)算也一步步走向成熟,低精度計(jì)算也是未來深度學(xué)習(xí)的一個(gè)重要趨勢,再不學(xué)習(xí)就?out?啦。
?
3. FP16帶來的問題:量化誤差
這個(gè)部分是整篇文章最重要的理論核心。
?
講了這么多?FP16?的好處,那么使用?FP16?的時(shí)候有沒有什么問題呢?當(dāng)然有。FP16 帶來的問題主要有兩個(gè):1.?溢出錯(cuò)誤;2.?舍入誤差。
?
溢出錯(cuò)誤(Grad Overflow / Underflow)由于?FP16?的動(dòng)態(tài)范圍比?FP32?的動(dòng)態(tài)范圍要狹窄很多,因此在計(jì)算過程中很容易出現(xiàn)上溢出(Overflow,g>65504)和下溢出(Underflow,)的錯(cuò)誤,溢出之后就會(huì)出現(xiàn)“Nan”的問題。
在深度學(xué)習(xí)中,由于激活函數(shù)的的梯度往往要比權(quán)重梯度小,更易出現(xiàn)下溢出的情況。
▲?下溢出問題
舍入誤差(Rounding Error)舍入誤差指的是當(dāng)梯度過小,小于當(dāng)前區(qū)間內(nèi)的最小間隔時(shí),該次梯度更新可能會(huì)失敗,用一張圖清晰地表示:
▲?舍入誤差
?
4.?解決問題的辦法:混合精度訓(xùn)練+動(dòng)態(tài)損失放大
混合精度訓(xùn)練(Mixed Precision)混合精度訓(xùn)練的精髓在于“在內(nèi)存中用?FP16?做儲(chǔ)存和乘法從而加速計(jì)算,用?FP32?做累加避免舍入誤差”。混合精度訓(xùn)練的策略有效地緩解了舍入誤差的問題。
損失放大(Loss Scaling)即使用了混合精度訓(xùn)練,還是會(huì)存在無法收斂的情況,原因是激活梯度的值太小,造成了下溢出(Underflow)。損失放大的思路是:
反向傳播前,將損失變化(dLoss)手動(dòng)增大倍,因此反向傳播時(shí)得到的中間變量(激活函數(shù)梯度)則不會(huì)溢出;
反向傳播后,將權(quán)重梯度縮小倍,恢復(fù)正常值。
Apex的新API:Automatic Mixed Precision (AMP)
曾經(jīng)的 Apex 混合精度訓(xùn)練的 API 仍然需要手動(dòng) half 模型以及輸入的數(shù)據(jù),比較麻煩,現(xiàn)在新的 API 只需要三行代碼即可無痛使用:
from?apex?import?amp model,?optimizer?=?amp.initialize(model,?optimizer,?opt_level="O1")?#?這里是“歐一”,不是“零一” with?amp.scale_loss(loss,?optimizer)?as?scaled_loss: scaled_loss.backward()opt_level?
其中只有一個(gè) opt_level 需要用戶自行配置:?
?O0 :純 FP32 訓(xùn)練,可以作為 accuracy 的 baseline;?
?O1?:混合精度訓(xùn)練(推薦使用),根據(jù)黑白名單自動(dòng)決定使用 FP16(GEMM, 卷積)還是 FP32(Softmax)進(jìn)行計(jì)算;?
?O2 :“幾乎 FP16”混合精度訓(xùn)練,不存在黑白名單,除了 Batch Norm,幾乎都是用 FP16 計(jì)算;
?O3 :純 FP16 訓(xùn)練,很不穩(wěn)定,但是可以作為 speed 的 baseline。
動(dòng)態(tài)損失放大(Dynamic Loss Scaling)?
AMP 默認(rèn)使用動(dòng)態(tài)損失放大,為了充分利用 FP16 的范圍,緩解舍入誤差,盡量使用最高的放大倍數(shù)(),如果產(chǎn)生了上溢出(Overflow),則跳過參數(shù)更新,縮小放大倍數(shù)使其不溢出,在一定步數(shù)后(比如 2000 步)會(huì)再嘗試使用大的 scale 來充分利用 FP16 的范圍:
▲?AMP中動(dòng)態(tài)損失放大的策略
干貨:踩過的那些坑
這一部分是整篇文章最干貨的部分,是筆者在最近在 apex 使用中的踩過的所有的坑,由于 apex 報(bào)錯(cuò)并不明顯,常常 debug 得讓人很沮喪,但只要注意到以下的點(diǎn),95% 的情況都可以暢通無阻了:?
1. 判斷你的 GPU 是否支持 FP16:支持的有擁有 Tensor Core 的 GPU(2080Ti、Titan、Tesla 等),不支持的(Pascal 系列)就不建議折騰了;
2. 常數(shù)的范圍:為了保證計(jì)算不溢出,首先要保證人為設(shè)定的常數(shù)(包括調(diào)用的源碼中的)不溢出,如各種 epsilon,INF 等;
3. Dimension 最好是 8 的倍數(shù):NVIDIA 官方的文檔的 2.2 條 [2] 表示,維度都是 8 的倍數(shù)的時(shí)候,性能最好;
4. 涉及到 sum 的操作要小心,很容易溢出,類似 Softmax 的操作建議用官方 API,并定義成 layer 寫在模型初始化里;
5. 模型書寫要規(guī)范:自定義的 Layer 寫在模型初始化函數(shù)里,graph 計(jì)算寫在 forward 里;
6. 某些不常用的函數(shù),在使用前需要注冊:?amp.register_float_function(torch, 'sigmoid') ;
7. 某些函數(shù)(如 einsum)暫不支持 FP16 加速,建議不要用的太 heavy,XLNet 的實(shí)現(xiàn)改 FP16 [4] 困擾了我很久;
8. 需要操作模型參數(shù)的模塊(類似 EMA),要使用 AMP 封裝后的 model;
9. 需要操作梯度的模塊必須在 optimizer 的 step 里,不然 AMP 不能判斷 grad 是否為 Nan;
10. 歡迎補(bǔ)充。
總結(jié)
這篇從理論到實(shí)踐地介紹了混合精度計(jì)算以及 Apex 新 API(AMP)的使用方法。筆者現(xiàn)在在做深度學(xué)習(xí)模型的時(shí)候,幾乎都會(huì)第一時(shí)間把代碼改成混合精度訓(xùn)練的了,速度快,精度還不減,確實(shí)是調(diào)參煉丹必備神器。目前網(wǎng)上還并沒有看到關(guān)于 AMP 以及使用時(shí)會(huì)遇到的坑的中文博客,所以這一篇也是希望大家在使用的時(shí)候可以少花一點(diǎn)時(shí)間 debug。當(dāng)然,如果讀者們有發(fā)現(xiàn)新的坑歡迎交流,我會(huì)補(bǔ)充在專欄的博客 [5] 中。
Reference
[1] Intel的低精度表示用于深度學(xué)習(xí)訓(xùn)練與推斷?
http://market.itcgb.com/Contents/Intel/OR_AI_BJ/images/Brian_DeepLearning_LowNumericalPrecision.pdf
[2] NVIDIA官方混合精度訓(xùn)練文檔
https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html?
[3] Apex官方使用文檔
https://nvidia.github.io/apex/amp.html
[4]?XLNet的實(shí)現(xiàn)改FP16
https://github.com/NVIDIA/apex/issues/394?
[5] 專欄博客
https://zhuanlan.zhihu.com/p/79887894
點(diǎn)擊以下標(biāo)題查看更多往期內(nèi)容:?
ICCV 2019 | 基于持續(xù)學(xué)習(xí)的條件圖像生成模型
@即將開學(xué)的你,請收好這份必讀論文清單
小米AutoML團(tuán)隊(duì)發(fā)布可伸縮超網(wǎng)SCARLET
Github大熱論文 | 基于GAN的新型無監(jiān)督圖像轉(zhuǎn)換
后BERT時(shí)代的那些NLP預(yù)訓(xùn)練模型
KDD Cup 2019 AutoML Track冠軍團(tuán)隊(duì)技術(shù)分享
#投 稿 通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達(dá)讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢?答案就是:你不認(rèn)識的人。
總有一些你不認(rèn)識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學(xué)者和學(xué)術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵(lì)高校實(shí)驗(yàn)室或個(gè)人,在我們的平臺(tái)上分享各類優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學(xué)習(xí)心得或技術(shù)干貨。我們的目的只有一個(gè),讓知識真正流動(dòng)起來。
??來稿標(biāo)準(zhǔn):
? 稿件確系個(gè)人原創(chuàng)作品,來稿需注明作者個(gè)人信息(姓名+學(xué)校/工作單位+學(xué)歷/職位+研究方向)?
? 如果文章并非首發(fā),請?jiān)谕陡鍟r(shí)提醒并附上所有已發(fā)布鏈接?
? PaperWeekly 默認(rèn)每篇文章都是首發(fā),均會(huì)添加“原創(chuàng)”標(biāo)志
? 投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請單獨(dú)在附件中發(fā)送?
? 請留下即時(shí)聯(lián)系方式(微信或手機(jī)),以便我們在編輯發(fā)布時(shí)和作者溝通
?
現(xiàn)在,在「知乎」也能找到我們了
進(jìn)入知乎首頁搜索「PaperWeekly」
點(diǎn)擊「關(guān)注」訂閱我們的專欄吧
關(guān)于PaperWeekly
PaperWeekly 是一個(gè)推薦、解讀、討論、報(bào)道人工智能前沿論文成果的學(xué)術(shù)平臺(tái)。如果你研究或從事 AI 領(lǐng)域,歡迎在公眾號后臺(tái)點(diǎn)擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
▽ 點(diǎn)擊 |?閱讀原文?| 獲取最新論文推薦
總結(jié)
以上是生活随笔為你收集整理的PyTorch必备神器 | 唯快不破:基于Apex的混合精度加速的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 北京 | 免费高效训练及OpenVINO
- 下一篇: “让Keras更酷一些!”:层中层与ma