import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

# Prepare the dataset.
batch_size = 64
# Directory where MNIST is downloaded/cached (a Windows path; adjust for your machine).
data_root = 'E:\\tmp\\pytorch/'
# ToTensor converts each image to a float tensor in [0, 1]; Normalize then applies
# (x - mean) / std with the commonly used MNIST mean (0.1307) and standard
# deviation (0.3081). Note the second argument is a standard deviation, not a variance.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = datasets.MNIST(root=data_root, train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root=data_root, train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)


# Design the model using a class.
class Net(torch.nn.Module):
    """Five-layer fully connected classifier mapping a flattened 28x28 image (784 values) to 10 class scores."""

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        """Return unnormalized class scores (logits) of shape (batch, 10)."""
        # Flatten each image to 784 values; -1 lets PyTorch infer the mini-batch size.
        x = x.view(-1, 784)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = F.relu(self.l4(x))
        # No activation on the last layer: CrossEntropyLoss expects raw logits
        # and applies log-softmax internally.
        return self.l5(x)


model = Net()
# Construct the loss and optimizer.
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


# Training cycle: forward, backward, update.
def train(epoch):
    """Run one training epoch over ``train_loader``.

    Every 300 batches, prints the mean per-batch loss of those 300 batches.

    Args:
        epoch: zero-based epoch index; printed as ``epoch + 1``.
    """
    model.train()
    running_loss = 0.0
    for batch_idx, (inputs, target) in enumerate(train_loader, 0):
        optimizer.zero_grad()
        # outputs: (batch, 10) logits; target: (batch,) class indices.
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0


def test():
    """Evaluate the model on ``test_loader`` and print the top-1 accuracy (truncated to an integer percent)."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            # Reduce over dim=1 (the class dimension): one predicted label per sample.
            _, predicted = torch.max(outputs, dim=1)
            total += labels.size(0)
            # Element-wise tensor comparison; summing the boolean mask counts the hits.
            correct += (predicted == labels).sum().item()
    print('accuracy on test set: %d %% ' % (100 * correct / total))


if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()
結果如下:
[1,   300] loss: 2.186
[1,   600] loss: 0.881
[1,   900] loss: 0.435
accuracy on test set: 89 %
[2,   300] loss: 0.323
[2,   600] loss: 0.262
[2,   900] loss: 0.225
accuracy on test set: 94 %
[3,   300] loss: 0.187
[3,   600] loss: 0.171
[3,   900] loss: 0.152
accuracy on test set: 95 %
[4,   300] loss: 0.127
[4,   600] loss: 0.127
[4,   900] loss: 0.110
accuracy on test set: 96 %
[5,   300] loss: 0.097
[5,   600] loss: 0.096
[5,   900] loss: 0.092
accuracy on test set: 96 %
[6,   300] loss: 0.080
[6,   600] loss: 0.074
[6,   900] loss: 0.073
accuracy on test set: 97 %
[7,   300] loss: 0.062
[7,   600] loss: 0.059
[7,   900] loss: 0.062
accuracy on test set: 97 %
[8,   300] loss: 0.052
[8,   600] loss: 0.048
[8,   900] loss: 0.048
accuracy on test set: 97 %
[9,   300] loss: 0.039
[9,   600] loss: 0.042
[9,   900] loss: 0.038
accuracy on test set: 96 %
[10,   300] loss: 0.034
[10,   600] loss: 0.032
[10,   900] loss: 0.033
accuracy on test set: 97 %