把显存用在刀刃上!17 种 pytorch 节约显存技巧
引導
- 1. 顯存都用在哪兒了?
- 2. 技巧 1:使用就地操作
- 3. 技巧 2:避免中間變量
- 4. 技巧 3:優化網絡模型
- 5. 技巧 4:減小 BATCH_SIZE
- 6. 技巧 5:拆分 BATCH
- 7. 技巧 6:降低 PATCH_SIZE
- 8. 技巧 7:優化損失求和
- 9. 技巧 8:調整訓練精度
- 10. 技巧 9:分割訓練過程
- 11. 技巧10:清理內存垃圾
- 12. 技巧11:使用梯度累積
- 13. 技巧12:清除不必要梯度
- 14. 技巧13:周期清理顯存
- 15. 技巧14:多使用下采樣
- 16. 技巧15:刪除無用變量
- 17. 技巧16:改變優化器
- 18. 終極技巧
1. 顯存都用在哪兒了?
一般在訓練神經網絡時,顯存主要被網絡模型和中間變量占用。
- 網絡模型中的卷積層,全連接層和標準化層等的參數占用顯存,而諸如激活層和池化層等本質上是不占用顯存的。
- 中間變量包括特征圖和優化器等,是消耗顯存最多的部分。
- 其實 pytorch 本身也占用一些顯存的,但占用不多,以下方法大致按照推薦的優先順序。
2. 技巧 1:使用就地操作
就地操作 (inplace) 字面理解就是在原地對變量進行操作,對應到 pytorch 中就是在原內存上對變量進行操作而不申請新的內存空間,從而減少對內存的使用。具體來說就地操作包括三個方面的實現途徑:
- 使用將 inplace 屬性定義為 True 的激活函數,如 nn.ReLU(inplace=True)
- 使用 pytorch 帶有就地操作的方法,一般是方法名后跟一個下劃線 “_”,如 tensor.add_(),tensor.scatter_(),F.relu_()
- 使用就地操作的運算符,如 y += x,y *= x
3. 技巧 2:避免中間變量
在自定義網絡結構的成員方法 forward 函數里,避免使用不必要的中間變量,盡量在之前已申請的內存里進行操作,比如下面的代碼就使用太多中間變量,占用大量不必要的顯存:
def forward(self, x):x0 = self.conv0(x) # 輸入層x1 = F.relu_(self.conv1(x0) + x0)x2 = F.relu_(self.conv2(x1) + x1)x3 = F.relu_(self.conv3(x2) + x2)x4 = F.relu_(self.conv4(x3) + x3)x5 = F.relu_(self.conv5(x4) + x4)x6 = self.conv(x5) # 輸出層return x6為了減少顯存占用,可以將上述 forward 函數修改如下:
def forward(self, x):x = self.conv0(x) # 輸入層x = F.relu_(self.conv1(x) + x)x = F.relu_(self.conv2(x) + x)x = F.relu_(self.conv3(x) + x)x = F.relu_(self.conv4(x) + x)x = F.relu_(self.conv5(x) + x)x = self.conv(x) # 輸出層return x上述兩段代碼實現的功能是一樣的,但對顯存的占用卻相去甚遠,后者能節省前者占用顯存的接近 90% 之多。
4. 技巧 3:優化網絡模型
網絡模型對顯存的占用主要指的就是卷積層,全連接層和標準化層等的參數,具體優化途徑包括但不限于:
- 減少卷積核數量 (=減少輸出特征圖通道數)
- 不使用全連接層
- 全局池化 nn.AdaptiveAvgPool2d() 代替全連接層 nn.Linear()
- 不使用標準化層
- 跳躍連接跨度不要太大太多 (避免產生大量中間變量)
5. 技巧 4:減小 BATCH_SIZE
- 在訓練卷積神經網絡時,epoch 代表的是數據整體進行訓練的次數,batch 代表將一個 epoch 拆分為 batch_size 批來參與訓練。
- 減小 batch_size 是一個減小顯存占用的慣用技巧,在訓練時顯存不夠一般優先減小 batch_size ,但 batch_size 不能無限變小,太大會導致網絡不穩定,太小會導致網絡不收斂。
6. 技巧 5:拆分 BATCH
拆分 batch 跟技巧 4 中減小 batch_size 本質是不一樣的, 這種拆分 batch 的操作可以理解為將兩次訓練的損失相加再反向傳播,但減小 batch_size 的操作是訓練一次反向傳播一次。拆分 batch 操作可以理解為三個步驟,假設原來 batch 的大小 batch_size=64:
- 將 batch 拆分為兩個 batch_size=32 的小 batch
- 分別輸入網絡與目標值計算損失,將得到的損失相加
- 進行反向傳播
7. 技巧 6:降低 PATCH_SIZE
- 在卷積神經網絡訓練中,patch_size 指的是輸入神經網絡的圖像大小,即(H*W)。
- 網絡輸入 patch 的大小對于后續特征圖的大小等影響非常大,訓練時可能采用諸如 [64*64],[128*128] 等大小的 patch,如果顯存不足可以進一步縮小 patch 的大小,比如 [32*32],[16*16]。
- 但這種方法存在問題,可能極大地影響網絡的泛化能力,在裁剪的時候一定要注意在原圖上隨機裁剪,一般不建議。
8. 技巧 7:優化損失求和
一個 batch 訓練結束會得到相應的一個損失值,如果要計算一個 epoch 的損失就需要累加之前產生的所有 batch 損失,但之前的 batch 損失在 GPU 中占用顯存,直接累加得到的 epoch 損失也會在 GPU 中占用顯存,可以通過如下方法進行優化:
epoch_loss += batch_loss.detach().item() # epoch 損失上邊代碼的效果就是首先解除 batch_loss 張量的 GPU 占用,將張量中的數據取出再進行累加。
9. 技巧 8:調整訓練精度
- 降低訓練精度
pytorch 中訓練神經網絡時浮點數默認使用 32 位浮點型數據,在訓練對于精度要求不是很高的網絡時可以改為 16 位浮點型數據進行訓練,但要注意同時將數據和網絡模型都轉為 16 位浮點型數據,否則會報錯。降低浮點型數據的操作實現過程非常簡單,但如果優化器選擇 Adam 時可能會報錯,選擇 SGD 優化器則不會報錯,具體操作步驟如下:
- 混合精度訓練
混合精度訓練指的是用 GPU 訓練網絡時,相關數據在內存中用半精度做儲存和乘法來加速計算,用全精度進行累加避免舍入誤差,這種混合經度訓練的方法可以令訓練時間減少一半左右,也可以很大程度上減小顯存占用。在 pytorch1.6 之前多使用 NVIDIA 提供的 apex 庫進行訓練,之后多使用 pytorch 自帶的 amp 庫,實例代碼如下:
10. 技巧 9:分割訓練過程
- 如果訓練的網絡非常深,比如 resnet101 就是一個很深的網絡,直接訓練深度神經網絡對顯存的要求非常高,一般一次無法直接訓練整個網絡。在這種情況下,可以將復雜網絡分割為兩個小網絡,分別進行訓練。
- checkpoint 是 pytorch 中一種用時間換空間的顯存不足解決方案,這種方法本質上減少的是參與一次訓練網絡整體的參數量,如下是一個實例代碼。
- 使用 checkpoint 進行網絡訓練要求輸入屬性 requires_grad=True ,在給出的代碼中將一個網絡結構拆分為 3 個子網絡進行訓練,對于沒有 nn.Sequential() 構建神經網絡的情況無非就是自定義的子網絡里多幾項,或者像例子中一樣單獨構建網絡塊。
- 對于由 nn.Sequential() 包含的大網絡塊 (小網絡塊時沒必要),可以使用 checkpoint_sequential 包來簡化實現,具體實現過程如下:
11. 技巧10:清理內存垃圾
- python 中定義的變量一般在使用結束時不會立即釋放資源,在訓練循環開始時可以利用如下代碼來回收內存垃圾。
12. 技巧11:使用梯度累積
- 由于顯存大小的限制,訓練大型網絡模型時無法使用較大的 batch_size ,而一般較大的 batch_size 能令網絡模型更快收斂。
- 梯度累積就是將多個 batch 計算得到的損失平均后累積再進行反向傳播,類似于技巧 5 中拆分 batch 的思想(但技巧 5 是將大 batch 拆小,訓練的依舊是大 batch,而梯度累積訓練的是小 batch)。
- 可以采用梯度累積的思想來模擬較大 batch_size 可以達到的效果,具體實現代碼如下:
13. 技巧12:清除不必要梯度
在運行測試程序時不涉及到與梯度有關的操作,因此可以清楚不必要的梯度以節約顯存,具體包括但不限于如下操作:
- 用代碼 model.eval() 將模型置于測試狀態,不啟用標準化和隨機舍棄神經元等操作。
- 測試代碼放入上下文管理器 with torch.no_grad(): 中,不進行圖構建等操作。
- 在訓練或測試每次循環開始時加梯度清零操作
14. 技巧13:周期清理顯存
- 同理也可以在訓練每次循環開始時利用 pytorch 自帶清理顯存的代碼來釋放不用的顯存資源。
執行這條語句釋放的顯存資源在用 Nvidia-smi 命令查看時體現不出,但確實是已經釋放。其實 pytorch 原則上是如果變量不再被引用會自動釋放,所以這條語句可能沒啥用,但個人覺得多少有點用。
15. 技巧14:多使用下采樣
下采樣從實現上來看類似池化,但不限于池化,其實也可以用步長大于 1 來代替池化等操作來進行下采樣。從結果上來看就是通過下采樣得到的特征圖會縮小,特征圖縮小自然參數量減少,進而節約顯存,可以用如下兩種方式實現:
nn.Conv2d(32, 32, 3, 2, 1) # 步長大于 1 下采樣nn.Conv2d(32, 32, 3, 1, 1) # 卷積核接池化下采樣 nn.MaxPool2d(2, 2)16. 技巧15:刪除無用變量
del 功能是徹底刪除一個變量,要再使用必須重新創建,注意 del 刪除的是一個變量而不是從內存中刪除一個數據,這個數據有可能也被別的變量在引用,實現方法很簡單,比如:
def forward(self, x):input_ = xx = F.relu_(self.conv1(x) + input_)x = F.relu_(self.conv2(x) + input_)x = F.relu_(self.conv3(x) + input_)del input_ # 刪除變量 input_x = self.conv4(x) # 輸出層return x17. 技巧16:改變優化器
進行網絡訓練時比較常用的優化器是 SGD 和 Adam,拋開訓練最后的效果來談,SGD 對于顯存的占用相比 Adam 而言是比較小的,實在沒有辦法時可以嘗試改變參數優化算法,兩種優化算法的調用是相似的:
import torch.optim as optim from torchvision.models import resnet18LEARNING_RATE = 1e-3 # 學習率 myNet = resnet18().cuda() # 實例化網絡optimizer_adam = optim.Adam(myNet.parameters(), lr=LEAENING_RATE) # adam 網絡參數優化算法 optimizer_sgd = optim.SGD(myNet.parameters(), lr=LEAENING_RATE) # sgd 網絡參數優化算法18. 終極技巧
購買顯存夠大的顯卡,一塊不行那就 多來幾塊。
總結
以上是生活随笔為你收集整理的把显存用在刀刃上!17 种 pytorch 节约显存技巧的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: jquery实现注册表单验证
- 下一篇: GitLab Admin Area