pytorch保存模型pth_Day159:模型的保存与加载
生活随笔
收集整理的這篇文章主要介紹了
pytorch保存模型pth_Day159:模型的保存与加载
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
網絡結構和參數可以分開的保存和加載,因此,pytorch保存模型有兩種方法:
- 注意到,兩者都是用torch.save(obj, dir)實現,這個函數的作用是將對象保存到磁盤中,它的內部是使用Python的pickle實現
- PyTorch約定使用.pt或.pth后綴命名保存文件
- 兩種方法的區別其實就是obj參數的不同:前者的obj是整個model對象,后者的obj是從model對象里獲取存儲了model參數的詞典,推薦用第二種,雖然麻煩了一丁點,但是比較靈活,有利于實現預訓練、參數遷移等操作
一般加載模型是在訓練完成后用模型做測試,這時候加載模型記得要加上model.eval(),把模型切換到evaluation模式,這時候會調整dropout和bactch的模式。
- 網絡結構及其參數的保存與加載:load整個模型,完成了模型的定義和參數的加載這兩個過程
- 只保存/加載模型參數:需要先創建一個網絡模型,然后再load_state_dict()
重點介紹一下這種方法,一般訓完一個模型之后不會只保存一個模型的參數,為了方便后續操作,比如恢復訓練、參數遷移等,會保存當前狀態的一個快照,格式以字典的格式存儲,具體信息可以根據自己的需要,下面列出幾個方面:
- 模型參數(不帶模型的結構)
- 優化器參數
- loss
- epoch
- args
把這些信息用字典包裝起來,然后保存即可。這種方式保存的只是參數,所以,在加載時需要先創建好模型,然后再把參數加載進去,如下:
# 獲得保存信息save_data = { 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, 'epoch': epoch, 'args': args ...}# 保存torch.save(save_data , path)# 加載參數model_CKPT = torch.load(path)model = Mymodel()optimizer = Myoptimizer()model.load_state_dict(model_CKPT ['model_state_dict'])optimizer.load_state_dict(model_CKPT ['optimizer_state_dict'])...# 若對于加載參數,用函數表示,比如:def load_checkpoint(model, checkpoint_path, optimizer): if checkpoint_path != None: model_CKPT = torch.load(checkpoint_path) model.load_state_dict(model_CKPT['state_dict']) print('loading checkpoint!') optimizer.load_state_dict(model_CKPT['optimizer']) return model, optimizer但是,對于已經保存好的模型參數,我們可能修改了一部分網絡結構,比如加了一些,刪除一些等等,那么需要過濾這些參數,加載方式如下:
def load_checkpoint(model, checkpoint_path, optimizer, loadOptimizer): if checkpoint_path != 'None': print("loading checkpoint...") model_dict = model.state_dict()# 修改后的模型隨機初始化的參數 modelCheckpoint = torch.load(checkpoint_path) # 修改前的模型參數 pretrained_dict = modelCheckpoint['model_state_dict'] # 過濾操作 new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} # 獲取修改后模型所需參數 model_dict.update(new_dict) # 打印出來,更新了多少參數 print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))# 修改后模型加載所需的,已經訓練好的參數 model.load_state_dict(model_dict) print("loaded finished!") # 如果不需要更新優化器那么設置為false if loadOptimizer == True: optimizer.load_state_dict(modelCheckpoint['optimizer_state_dict']) print('loaded! optimizer') else: print('not loaded optimizer') else: print('No checkpoint_path is included') return model, optimizer參考1:https://blog.csdn.net/MoreAction_/article/details/107967053
參考2:https://zhuanlan.zhihu.com/p/38056115
總結
以上是生活随笔為你收集整理的pytorch保存模型pth_Day159:模型的保存与加载的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: adxl276怎么添加到proteus中
- 下一篇: spark读取hdfs路径下的数据_到底