PyTorch-混合精度训练
簡介
自動混合精度訓練(auto Mixed Precision,amp)是深度學習比較流行的一個訓練技巧,它可以大幅度降低訓練的成本并提高訓練的速度,因此在競賽中受到了較多的關注。此前,比較流行的混合精度訓練工具是由NVIDIA開發的A PyTorch Extension(Apex),它能夠以非常簡單的API支持自動混合精度訓練,不過,PyTorch從1.6版本開始已經內置了amp模塊,本文簡單介紹其使用。
自動混合精度(AMP)
首先來聊聊自動混合精度的由來。下圖是常見的浮點數表示形式,它表示單精度浮點數,在編程語言中的體現是float型,顯然從圖中不難看出它需要4個byte也就是32bit來進行存儲。深度學習的模型數據均采用float32進行表示,這就帶來了兩個問題:模型size大,對顯存要求高;32位計算慢,導致模型訓練和推理速度慢。
那么半精度是什么呢,顧名思義,它只用16位即2byte來進行表示,較小的存儲占用以及較快的運算速度可以緩解上面32位浮點數的兩個主要問題,因此半精度會帶來下面的一些優勢:
那么,半精度有沒有什么問題呢?其實也是有著很致命的問題的,主要是移除錯誤和舍入誤差兩個方面,具體可以參考這篇文章,作者解析的很好,我這里就簡單復述一下。
溢出錯誤
FP16的數值表示范圍比FP32的表示范圍小很多,因此在計算過程中很容易出現上溢出(overflow)和下溢出(underflow)問題,溢出后會出現梯度nan問題,導致模型無法正確更新,嚴重影響網絡的收斂。而且,深度模型訓練,由于激活函數的梯度往往比權重的梯度要小,更容易出現的是下溢出問題。
舍入誤差
舍入誤差(Rounding Error)指的是當梯度過小,小于當前區間內的最小間隔時,該次梯度更新可能會失敗。上面說的知乎文章的作者用來一張很形象的圖進行解釋,具體如下,意思是說在2?32^{-3}2?3到2?22^{-2}2?2之間,2?32^{-3}2?3每次變大都會至少加上2?132^{-13}2?13,顯然,梯度還在這個間隔內,因此更新是失敗的。
那么這兩個問題是如何解決的呢,思路來自于NVIDIA和百度合作的論文,我這里簡述一下方法:混合精度訓練和損失縮放。前者的思路是在內存中使用FP16做儲存和乘法運算以加速計算,用FP32做累加運算以避免舍入誤差,這樣就緩解了舍入誤差的問題;后者則是針對梯度值太小從而下溢出的問題,它的思想是:反向傳播前,將損失變化手動增大2k2^k2k倍,因此反向傳播時得到的中間變量(激活函數梯度)則不會溢出;反向傳播后,將權重梯度縮小2k2^k2k倍,恢復正常值。
研究人員通過引入FP32進行混合精度訓練以及通過損失縮放來解決FP16的不足,從而實現了一套混合精度訓練的范式,NVIDIA以此為基礎設計了Apex包,不過Apex的使用本文就不涉及了,下一節主要關注如何使用torch.cuda.amp實現自動混合精度訓練,不過這里還需要補充的一點就是目前混合精度訓練支持的N卡只有包含Tensor Core的卡,如2080Ti、Titan、Tesla等。
PyTorch自動混合精度
PyTorch對混合精度的支持始于1.6版本,位于torch.cuda.amp模塊下,主要是torch.cuda.amp.autocast和torch.cuda.amp.GradScale兩個模塊,autocast針對選定的代碼塊自動選取適合的計算精度,以便在保持模型準確率的情況下最大化改善訓練效率;GradScaler通過梯度縮放,以最大程度避免使用FP16進行運算時的梯度下溢。官方給的使用這兩個模塊進行自動精度訓練的示例代碼鏈接給出,我對其示例解析如下,這就是一般的訓練框架。
# 以默認精度創建模型和優化器 model = Net().cuda() optimizer = optim.SGD(model.parameters(), ...)# 創建梯度縮放器 scaler = GradScaler()for epoch in epochs:for input, target in data:optimizer.zero_grad()# 通過自動類型轉換進行前向傳播with autocast():output = model(input)loss = loss_fn(output, target)# 縮放大損失,反向傳播不建議放到autocast下,它默認和前向采用相同的計算精度scaler.scale(loss).backward()# 先反縮放梯度,若反縮后梯度不是inf或者nan,則用于權重更新scaler.step(optimizer)# 更新縮放器scaler.update()下面我以簡單的MNIST任務做測試,使用的顯卡為RTX 3090,代碼如下。該代碼段中只包含核心的訓練模塊,模型的定義和數據集的加載熟悉PyTorch的應該不難自行補充。
model = Model() model = model.cuda() loss_fn = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters())n_epochs = 30 start = time.time() for epoch in range(n_epochs):total_loss, correct, total = 0.0, 0, 0model.train()for step, data in enumerate(data_loader_train):x_train, y_train = datax_train, y_train = x_train.cuda(), y_train.cuda()outputs = model(x_train)_, pred = torch.max(outputs, 1)loss = loss_fn(outputs, y_train)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()total += len(y_train)correct += torch.sum(pred == y_train).item()print("epoch {} loss {} acc {}".format(epoch, total_loss, correct / total))我這里采用的是一個很小的模型,又是一個很簡單的任務,因此模型都是很快收斂,因此精度上沒有什么明顯的區別,不過如果是訓練大型模型的話,有人已經用實驗證明,內置amp和apex庫都會有精度下降,不過amp效果更好一些,下降較少。上面的loss變化圖也是非常類似的。
再來看存儲方面,顯存縮減在這個任務中的表現不是特別明顯,因為這個任務的參數量不多,前后向過程中的FP16存儲節省不明顯,而因為引入了一些拷貝之類的,反而使得顯存略有上升,實際的任務中,這種開銷肯定遠小于FP32的開銷的。
最后,不妨看一下使用混合精度最關心的速度問題,實際上混合精度確實會帶來一些速度上的優勢,一些官方的大模型如BERT等訓練速度提高了2-3倍,這對于工業界的需求來說,啟發還是比較多的。
總結
混合精度計算是未來深度學習發展的重要方向,很受工業界的關注,PyTorch從1.6版本開始默認支持amp,雖然現在還不是特別完善,但以后一定會越來越好,因此熟悉自動混合精度的用法還是有必要的。
超強干貨來襲 云風專訪:近40年碼齡,通宵達旦的技術人生總結
以上是生活随笔為你收集整理的PyTorch-混合精度训练的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 2020已去,2021未来
- 下一篇: FcaNet解读