PyTorch Image Classification Demo
Introduction
Earlier posts in this series covered data loading and augmentation, model construction, loss functions and optimizers, and training visualization. This post uses the Caltech101 image dataset to demonstrate the end-to-end PyTorch workflow; the general approach it describes is how a deep learning project with PyTorch typically proceeds.
Data Preparation
A custom Dataset is used to load the images in batches and apply the corresponding augmentations. The dataset used here has already been split, and a desc CSV file has been generated for each split; see the data preparation post for the details of that step.
The core code is below; the full training script is on GitHub (linked at the end).
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class MyDataset(Dataset):
    def __init__(self, desc_file, transform=None):
        # each row of the desc CSV holds an image path and its class label
        self.all_data = pd.read_csv(desc_file).values
        self.transform = transform

    def __getitem__(self, index):
        img, label = self.all_data[index, 0], self.all_data[index, 1]
        img = Image.open(img).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.all_data)


desc_train = '../data/desc_train.csv'
desc_valid = '../data/desc_valid.csv'
desc_test = '../data/desc_test.csv'

batch_size = 16
lr = 0.001
epochs = 10
norm_mean = [0.4948052, 0.48568845, 0.44682974]
norm_std = [0.24580306, 0.24236229, 0.2603115]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std)  # precomputed per-channel statistics
])
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std)
])
test_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std)
])

train_data = MyDataset(desc_train, transform=train_transform)
valid_data = MyDataset(desc_valid, transform=valid_transform)
test_data = MyDataset(desc_test, transform=test_transform)

# build the DataLoaders; only the training set is shuffled
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size)
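For reference, here is a quick sanity check of what comes out of the loader. The CSV layout shown in the comment is an assumption (the paths are hypothetical) matching how MyDataset indexes its columns:

# Assumed desc CSV layout (hypothetical paths):
#   ../data/101_ObjectCategories/accordion/image_0001.jpg,0
#   column 0 = image path, column 1 = integer class label

x, y = next(iter(train_loader))
print(x.shape)  # torch.Size([16, 3, 32, 32]) -- a batch of normalized images
print(y.shape)  # torch.Size([16])            -- the matching class labels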
Model Construction

The model is a simple convolutional classifier: two conv/pool stages followed by three fully connected layers.
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3))
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 6 * 6, 256)  # 6x6 feature maps after two conv/pool stages
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 101)  # one logit per Caltech101 class

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 6 * 6)  # flatten for the fully connected layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight.data, 0, 0.01)
                m.bias.data.zero_()
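The 64 * 6 * 6 input size of fc1 follows from the 32x32 input: a 3x3 convolution without padding shrinks 32 to 30, pooling halves that to 15, the second convolution shrinks it to 13, and pooling floors it to 6. A throwaway check with a dummy batch (not part of the training script) confirms the arithmetic:

import torch

net = Net()
net.init_weights()
out = net(torch.randn(1, 3, 32, 32))  # one dummy 32x32 RGB image
print(out.shape)  # torch.Size([1, 101])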
Loss and Optimization

The setup here is also quite standard: cross-entropy loss with momentum SGD, plus an automatic learning-rate decay driven by the validation loss.
import torch.optim as optim

model = Net()  # instantiate the network defined above

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, dampening=0.1)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True)
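ReduceLROnPlateau lowers the learning rate when a monitored metric stops improving. For clarity, here is the same scheduler with its relevant defaults (per the PyTorch documentation) spelled out:

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',    # a lower metric (here, validation loss) counts as improvement
    factor=0.1,    # multiply the learning rate by 0.1 on a plateau
    patience=10,   # wait 10 epochs without improvement before decaying
    verbose=True,  # log a message each time the rate is reduced
)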
Model Training

The model is trained in mini-batches; after each epoch it is evaluated on the validation set, and the losses and accuracies are logged to TensorBoard.
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('logs')  # assumed setup; the full script creates the writer once

for epoch in range(epochs):
    # training pass
    model.train()  # restore train mode (eval mode is set during validation below)
    train_loss = 0.0
    correct = 0.0
    total = 0.0
    for step, data in enumerate(train_loader):
        x, y = data
        out = model(x)
        loss = criterion(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        _, pred = torch.max(out.data, 1)
        total += y.size(0)
        correct += (pred == y).sum().item()
        train_loss += loss.item()
        if step % 100 == 0:
            print("epoch", epoch, "step", step, "loss", loss.item())
    train_acc = correct / total

    # validation pass
    model.eval()  # set once per pass rather than once per batch
    valid_loss = 0.0
    correct = 0.0
    total = 0.0
    for step, data in enumerate(valid_loader):
        x, y = data
        out = model(x)
        out.detach_()  # gradients are not needed for validation
        loss = criterion(out, y)
        _, pred = torch.max(out.data, 1)
        valid_loss += loss.item()
        total += y.size(0)
        correct += (pred == y).sum().item()
    valid_acc = correct / total

    # decay the learning rate when the validation loss plateaus
    scheduler.step(valid_loss)
    writer.add_scalars('loss', {'train_loss': train_loss, 'valid_loss': valid_loss}, epoch)
    writer.add_scalars('accuracy', {'train_acc': train_acc, 'valid_acc': valid_acc}, epoch)

The TensorBoard curves (not reproduced here) show that, although the training set is small and the number of epochs low, the model still converges normally.
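One refinement worth considering for the validation pass above: wrapping it in torch.no_grad() disables gradient tracking for the whole forward computation, which saves memory compared to detaching the output afterwards. A minimal sketch:

model.eval()
with torch.no_grad():  # no autograd bookkeeping anywhere in this block
    for x, y in valid_loader:
        out = model(x)
        loss = criterion(out, y)
        # accumulate loss/accuracy exactly as before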
Model Saving
Save the trained parameters so they can be reused later on the test set or deployed to another machine.
net_save_path = 'net_params.pkl'
torch.save(model.state_dict(), net_save_path)
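A minimal sketch of loading these parameters back for inference; since only the state_dict is saved, the Net class must be available wherever the weights are restored:

model = Net()
model.load_state_dict(torch.load(net_save_path))
model.eval()  # switch to inference mode before evaluating on the test set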
Additional Notes

This post has walked through the full workflow of training a deep model with PyTorch. In practice this workflow barely changes from project to project; the next post will cover the commonly reused boilerplate. All of the code in this post is available on my GitHub; stars and forks are welcome.