PyTorch【torchvision】
pytorch自發布以來,由于其便捷性,贏得了越來越多人的喜愛。
Pytorch有很多方便易用的包,今天要談的是torchvision包,它包括3個子包,分別是: torchvison.datasets ,torchvision.models ,torchvision.transforms ,分別是預定義好的數據集(比如MNIST、CIFAR10等)、預定義好的經典網絡結構(比如AlexNet、VGG、ResNet等)和預定義好的數據增強方法(比如Resize、ToTensor等)。這些方法可以直接調用,簡化我們建模的過程,也可以作為我們學習或構建新的模型的參考。
本文,我們講述的是models,且只談模型的加載。models這個包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用的網絡結構,并且提供了預訓練模型,可以通過簡單調用來讀取網絡結構和預訓練模型。
模型地址:https://github.com/pytorch/vision/tree/master/torchvision/models
官方文檔:https://pytorch.org/docs/master/torchvision/models.html
我將加載的方法簡單總結為以下四種:
1.直接加載預訓練模型
import torchvision.models as models resnet50 = models.resnet50(pretrained=True)?
這樣就導入了resnet50的預訓練模型了。
如果只需要網絡結構,不需要用預訓練模型的參數來初始化,那么就是:
model =torchvision.models.resnet50(pretrained=False)或者把resnet復制到自己的目錄下,新建個model文件夾
可以參考下面的貓狗大戰入門算法入門
https://github.com/JackwithWilshere/Kaggle-Dogs_vs_Cats_PyTorch
2.修改某一層
?以resnet為例,默認的是ImageNet的1000類,比如我們要做二分類,分類貓和狗
resnet 第一層卷積的卷積核是7,我們可能想改成5,那么可以通過以下方法修改:
#未經試驗,修改需要有理論依據,計算featuremap維度使之匹配。
resnet.conv1 = nn.Conv2d(3, 64,kernel_size=5, stride=2, padding=3, bias=False)3.加載部分預訓練模型
對于具體的任務,很難保證模型和公開的模型完全一樣,但是預訓練模型的參數確實有助于提高訓練的準確率,為了結合二者的優點,就需要我們加載部分預訓練模型。
方法二:
使用這種方法,將會保存模型的參數和結構信息。
(1)保存
torch.save (model, PATH)(2)恢復
model = torch.load(PATH)參考資料:
1. https://zhuanlan.zhihu.com/p/25980324
2. http://www.pytorchtutorial.com/pytorch-note5-save-and-restore-models/
3.https://blog.csdn.net/weixin_41278720/article/details/80759933?
?
總結
以上是生活随笔為你收集整理的PyTorch【torchvision】的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 八皇后问题 思路
- 下一篇: 回头再学Asp.net系列--基础篇(六