日韩性视频-久久久蜜桃-www中文字幕-在线中文字幕av-亚洲欧美一区二区三区四区-撸久久-香蕉视频一区-久久无码精品丰满人妻-国产高潮av-激情福利社-日韩av网址大全-国产精品久久999-日本五十路在线-性欧美在线-久久99精品波多结衣一区-男女午夜免费视频-黑人极品ⅴideos精品欧美棵-人人妻人人澡人人爽精品欧美一区-日韩一区在线看-欧美a级在线免费观看

歡迎訪問 生活随笔!

生活随笔

當前位置: 首頁 > 编程资源 > 编程问答 >内容正文

编程问答

MNIST手写数字识别

發(fā)布時間:2024/9/30 编程问答 21 豆豆
生活随笔 收集整理的這篇文章主要介紹了 MNIST手写数字识别 小編覺得挺不錯的,現(xiàn)在分享給大家,幫大家做個參考.

進入到研究生階段了,從頭學一下Pytorch,在這個小破站上記錄一下自己的學習過程。
本文使用的是Pytorch來做手寫數(shù)字的識別。

step0:先引入一些相關的包和庫

import torch from torch import nn from torch.nn import functional as F from torch import optim import torchvision from matplotlib import pyplot as pltfrom utils import plot_image,plot_curve,one_hot

這里的utils是定義的一些輔助工具,包括loss下降的繪圖函數(shù)和one_hot編碼及圖片顯示的輔助函數(shù)。代碼如下:
utils.py

# !/usr/bin/python3 # -*- coding:utf-8 -*- # Author:WeiFeng Liu # @Time: 2021/10/26 下午4:47import torch from matplotlib import pyplot as plt###loss下降 def plot_curve(data):fig = plt.figure()plt.plot(range(len(data)), data, color='blue')plt.legend(['value'], loc='upper right')plt.xlabel('step')plt.ylabel('value')plt.show()def plot_image(img,label,name):fig = plt.figure()for i in range(6):plt.subplot(2,3,i+1)plt.tight_layout()plt.imshow(img[i][0]*0.3081+0.1307,cmap='gray',interpolation='none')plt.title("{}:{}".format(name,label[i].item()))plt.xticks([])plt.yticks([])plt.show()def one_hot(labels,depth=10):out = torch.zeros(labels.size(0),depth)idx = torch.LongTensor(labels).view(-1,1)out.scatter_(dim = 1, index = idx,value=1)return out

step1:加載數(shù)據(jù)
使用torch的DataLoader方法加載數(shù)據(jù),MNIST數(shù)據(jù)集中的圖片大小為28*28,比較小,batch_size可以設置大一點。

batch_size = 512 ###step1 load dataset train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data',train=True,download=True,transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),#數(shù)據(jù)歸一化torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size,shuffle = True )test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/',train=False,download=True,transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size , shuffle = False )

transforms.Compose方法將數(shù)據(jù)轉為Tensor和做數(shù)據(jù)歸一化,訓練集中設置shuffle=True是將訓練數(shù)據(jù)打亂.

step2:定義網絡結構
使用簡單的三層線性模型來做簡單的識別。

class Net(nn.Module):def __init__(self):super(Net,self).__init__()self.fc1 = nn.Linear(28*28,256)self.fc2 = nn.Linear(256,64)self.fc3 = nn.Linear(64,10)def forward(self, x):#x:[batch_size,1,28,28]x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x

step3:train
訓練3個epoch

train_loss = [] net =Net() optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9)for epoch in range(3):for batch_idx,(x,y) in enumerate(train_loader):# x:[batch_size,1,28,28]#將x打平成二維的# y:batch_sizex = x.view(x.size(0),28*28)out = net(x)y_onehot = one_hot(y)##lose = mse(y,out)loss = F.mse_loss(out,y_onehot)optimizer.zero_grad() #梯度清零loss.backward() #計算梯度optimizer.step() #更新參數(shù)##打印losstrain_loss.append(loss.item())if batch_idx % 10 == 0:print(epoch,batch_idx,loss.item()) plot_curve(train_loss)

step4:test
最后在驗證集測試訓練的準確率

total_correct = 0 for x,y in test_loader:x = x.view(x.size(0),28*28)out = net(x)pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item()total_correct += correct total_num = len(test_loader.dataset)acc = total_correct / total_num print("test acc:",acc)

總結

以上是生活随笔為你收集整理的MNIST手写数字识别的全部內容,希望文章能夠幫你解決所遇到的問題。

如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。