PyTorch基础-猫狗分类实战-10
生活随笔
收集整理的這篇文章主要介紹了
PyTorch基础-猫狗分类实战-10
小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.
訓(xùn)練模型并保存
import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets,transforms,models from torch.utils.data import Dataset import sys # 數(shù)據(jù)預(yù)處理 transform = transforms.Compose([transforms.RandomResizedCrop(224),# 對(duì)圖像進(jìn)行隨機(jī)裁剪transforms.RandomRotation(20),# 隨機(jī)旋轉(zhuǎn)角度transforms.RandomHorizontalFlip(p=0.5),# 隨機(jī)水平翻轉(zhuǎn)transforms.ToTensor()# 變成tensor格式 ]) # 數(shù)據(jù)增強(qiáng)# 讀取數(shù)據(jù) root = "image" train_dataset = datasets.ImageFolder(root + "/train",transform) test_dataset = datasets.ImageFolder(root + "/test",transform)# 導(dǎo)入數(shù)據(jù) train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=8,shuffle=True) test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=8,shuffle=True) classes = train_dataset.classes classes_index = train_dataset.class_to_idx print(classes) print(classes_index) model = models.vgg16(pretrained=True)# 載入vgg16預(yù)訓(xùn)練模型 print(model) for param in model.parameters():param.requires_grad = False # 構(gòu)建新的全連接層 model.classifier = torch.nn.Sequential(torch.nn.Linear(25088,100),torch.nn.ReLU(),torch.nn.Dropout(p=0.5),torch.nn.Linear(100,2)) LR = 0.0003 # 定義代價(jià)函數(shù) entropy_loss = nn.CrossEntropyLoss() # 定義優(yōu)化器 optimizer = optim.Adam(model.parameters(),LR) def train():model.train()for i,data in enumerate(train_loader):# 獲得數(shù)據(jù)和對(duì)應(yīng)的標(biāo)簽inputs,labels = data# 獲得模型預(yù)測結(jié)果(64,10)out = model(inputs)# 交叉熵代價(jià)函數(shù)out(batch.C),labels(batch)loss = entropy_loss(out,labels)# 梯度清零optimizer.zero_grad()# 計(jì)算梯度loss.backward()# 修改權(quán)值optimizer.step()def test():model.eval()correct = 0for i,data in enumerate(test_loader):# 獲得數(shù)據(jù)和對(duì)應(yīng)的標(biāo)簽inputs,labels = data# 獲得模型預(yù)測結(jié)果out = model(inputs)# 獲得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 預(yù)測正確的數(shù)量correct += (predicted == labels).sum()print("test acc:{0}".format(correct.item()/len(test_dataset)))correct = 0for i,data in enumerate(train_loader):# 獲得數(shù)據(jù)和對(duì)應(yīng)的標(biāo)簽inputs,labels = data# 獲得模型預(yù)測結(jié)果out = model(inputs)# 獲得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 預(yù)測正確的數(shù)量correct += (predicted == labels).sum()print("train acc:{0}".format(correct.item()/len(train_dataset))) for epoch in range(5):print("epoch:",epoch)train()test() torch.save(model.state_dict(),"cat_dog.pth") # 保存模型加載模型進(jìn)行預(yù)測
import torch import numpy as np from PIL import Image from torchvision import transforms,models model = models.vgg16(pretrained=True) # 構(gòu)建新的全連接層 model.classifier = torch.nn.Sequential(torch.nn.Linear(25088,100),torch.nn.ReLU(),torch.nn.Dropout(p=0.5),torch.nn.Linear(100,2)) model.load_state_dict(torch.load("cat_dog.pth")) # 加載模型 model.eval() # 預(yù)測模式 label = np.array(["cat","dog"]) # 數(shù)據(jù)預(yù)處理 transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor() ]) # 預(yù)測函數(shù) def predict(image_path):# 打開圖片img = Image.open(image_path)# 數(shù)據(jù)處理,增加一個(gè)維度img = transform(img).unsqueeze(0)# 預(yù)測得到的結(jié)果outputs = model(img)# 獲得最大值所在位置_,predicted = torch.max(outputs,1)# 轉(zhuǎn)換為類別名稱print(label[predicted.item()]) predict("image/test/cat/cat.1490.jpg")總結(jié)
以上是生活随笔為你收集整理的PyTorch基础-猫狗分类实战-10的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: PyTorch基础-模型的保存和加载-0
- 下一篇: 机器学习基础-一元线性回归-01