深度学习pytorch--MNIST数据集
圖像分類數(shù)據(jù)集(Fashion-MNIST)
在介紹softmax回歸的實(shí)現(xiàn)前我們先引入一個(gè)多類圖像分類數(shù)據(jù)集。它將在后面的章節(jié)中被多次使用,以方便我們觀察比較算法之間在模型精度和計(jì)算效率上的區(qū)別。圖像分類數(shù)據(jù)集中最常用的是手寫數(shù)字識(shí)別數(shù)據(jù)集MNIST[1]。但大部分模型在MNIST上的分類精度都超過了95%。為了更直觀地觀察算法之間的差異,我們將使用一個(gè)圖像內(nèi)容更加復(fù)雜的數(shù)據(jù)集Fashion-MNIST[2](這個(gè)數(shù)據(jù)集也比較小,只有幾十M,沒有GPU的電腦也能吃得消)。
本節(jié)我們將使用torchvision包,它是服務(wù)于PyTorch深度學(xué)習(xí)框架的,主要用來構(gòu)建計(jì)算機(jī)視覺模型。torchvision主要由以下幾部分構(gòu)成:
獲取數(shù)據(jù)集
首先導(dǎo)入本節(jié)需要的包或模塊。
import torch import torchvision import torchvision.transforms as transforms import matplotlib.pyplot as plt import time import sys sys.path.append("..") # 為了導(dǎo)入上層目錄的d2lzh_pytorch import d2lzh_pytorch as d2l下面,我們通過torchvision的torchvision.datasets來下載這個(gè)數(shù)據(jù)集。第一次調(diào)用時(shí)會(huì)自動(dòng)從網(wǎng)上獲取數(shù)據(jù)。我們通過參數(shù)train來指定獲取訓(xùn)練數(shù)據(jù)集或測試數(shù)據(jù)集(testing data set)。測試數(shù)據(jù)集也叫測試集(testing set),只用來評(píng)價(jià)模型的表現(xiàn),并不用來訓(xùn)練模型。
另外我們還指定了參數(shù)transform = transforms.ToTensor()使所有數(shù)據(jù)轉(zhuǎn)換為Tensor,如果不進(jìn)行轉(zhuǎn)換則返回的是PIL圖片。transforms.ToTensor()將尺寸為 (H x W x C) 且數(shù)據(jù)位于[0, 255]的PIL圖片或者數(shù)據(jù)類型為np.uint8的NumPy數(shù)組轉(zhuǎn)換為尺寸為(C x H x W)且數(shù)據(jù)類型為torch.float32且位于[0.0, 1.0]的Tensor。
注意: 由于像素值為0到255的整數(shù),所以剛好是uint8所能表示的范圍,包括transforms.ToTensor()在內(nèi)的一些關(guān)于圖片的函數(shù)就默認(rèn)輸入的是uint8型,若不是,可能不會(huì)報(bào)錯(cuò)但可能得不到想要的結(jié)果。所以,如果用像素值(0-255整數(shù))表示圖片數(shù)據(jù),那么一律將其類型設(shè)置成uint8,避免不必要的bug。 本人就被這點(diǎn)坑過,詳見這個(gè)博客2.2.4節(jié)。
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor()) mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())上面的mnist_train和mnist_test都是torch.utils.data.Dataset的子類,所以我們可以用len()來獲取該數(shù)據(jù)集的大小,還可以用下標(biāo)來獲取具體的一個(gè)樣本。訓(xùn)練集中和測試集中的每個(gè)類別的圖像數(shù)分別為6,000和1,000。因?yàn)橛?0個(gè)類別,所以訓(xùn)練集和測試集的樣本數(shù)分別為60,000和10,000。
print(type(mnist_train)) print(len(mnist_train), len(mnist_test))輸出:
<class 'torchvision.datasets.mnist.FashionMNIST'> 60000 10000我們可以通過下標(biāo)來訪問任意一個(gè)樣本:
feature, label = mnist_train[0] print(feature.shape, label) # Channel x Height x Width輸出:
torch.Size([1, 28, 28]) tensor(9)變量feature對(duì)應(yīng)高和寬均為28像素的圖像。由于我們使用了transforms.ToTensor(),所以每個(gè)像素的數(shù)值為[0.0, 1.0]的32位浮點(diǎn)數(shù)。需要注意的是,feature的尺寸是 (C x H x W) 的,而不是 (H x W x C)。第一維是通道數(shù),因?yàn)閿?shù)據(jù)集中是灰度圖像,所以通道數(shù)為1。后面兩維分別是圖像的高和寬。
Fashion-MNIST中一共包括了10個(gè)類別,分別為t-shirt(T恤)、trouser(褲子)、pullover(套衫)、dress(連衣裙)、coat(外套)、sandal(涼鞋)、shirt(襯衫)、sneaker(運(yùn)動(dòng)鞋)、bag(包)和ankle boot(短靴)。以下函數(shù)可以將數(shù)值標(biāo)簽轉(zhuǎn)成相應(yīng)的文本標(biāo)簽。
# 本函數(shù)已保存在d2lzh包中方便以后使用 def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]下面定義一個(gè)可以在一行里畫出多張圖像和對(duì)應(yīng)標(biāo)簽的函數(shù)。
# 本函數(shù)已保存在d2lzh包中方便以后使用 def show_fashion_mnist(images, labels):d2l.use_svg_display()# 這里的_表示我們忽略(不使用)的變量_, figs = plt.subplots(1, len(images), figsize=(12, 12))for f, img, lbl in zip(figs, images, labels):f.imshow(img.view((28, 28)).numpy())f.set_title(lbl)f.axes.get_xaxis().set_visible(False)f.axes.get_yaxis().set_visible(False)plt.show()現(xiàn)在,我們看一下訓(xùn)練數(shù)據(jù)集中前10個(gè)樣本的圖像內(nèi)容和文本標(biāo)簽。
X, y = [], [] for i in range(10):X.append(mnist_train[i][0])y.append(mnist_train[i][1]) show_fashion_mnist(X, get_fashion_mnist_labels(y))讀取小批量
我們將在訓(xùn)練數(shù)據(jù)集上訓(xùn)練模型,并將訓(xùn)練好的模型在測試數(shù)據(jù)集上評(píng)價(jià)模型的表現(xiàn)。前面說過,mnist_train是torch.utils.data.Dataset的子類,所以我們可以將其傳入torch.utils.data.DataLoader來創(chuàng)建一個(gè)讀取小批量數(shù)據(jù)樣本的DataLoader實(shí)例。
在實(shí)踐中,數(shù)據(jù)讀取經(jīng)常是訓(xùn)練的性能瓶頸,特別當(dāng)模型較簡單或者計(jì)算硬件性能較高時(shí)。PyTorch的DataLoader中一個(gè)很方便的功能是允許使用多進(jìn)程來加速數(shù)據(jù)讀取。這里我們通過參數(shù)num_workers來設(shè)置4個(gè)進(jìn)程讀取數(shù)據(jù)。
batch_size = 256 if sys.platform.startswith('win'):num_workers = 0 # 0表示不用額外的進(jìn)程來加速讀取數(shù)據(jù) else:num_workers = 4 train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers) test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)我們將獲取并讀取Fashion-MNIST數(shù)據(jù)集的邏輯封裝在d2lzh_pytorch.load_data_fashion_mnist函數(shù)中供后面章節(jié)調(diào)用。該函數(shù)將返回train_iter和test_iter兩個(gè)變量。隨著本書內(nèi)容的不斷深入,我們會(huì)進(jìn)一步改進(jìn)該函數(shù)。
最后我們查看讀取一遍訓(xùn)練數(shù)據(jù)需要的時(shí)間。
start = time.time() for X, y in train_iter:continue print('%.2f sec' % (time.time() - start))輸出:
1.57 sec小結(jié)
- Fashion-MNIST是一個(gè)10類服飾分類數(shù)據(jù)集,之后章節(jié)里將使用它來檢驗(yàn)不同算法的表現(xiàn)。
- 我們將高和寬分別為hhh和www像素的圖像的形狀記為h×wh \times wh×w或(h,w)。
參考文獻(xiàn)
[1] LeCun, Y., Cortes, C., & Burges, C. http://yann.lecun.com/exdb/mnist/
[2] Xiao, H., Rasul, K., & Vollgraf, R. (2017). Fashion-mnist: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747.
注:本節(jié)除了代碼之外與原書基本相同,原書傳送門
轉(zhuǎn)載至:動(dòng)手學(xué)習(xí)深度學(xué)習(xí)pytorch
總結(jié)
以上是生活随笔為你收集整理的深度学习pytorch--MNIST数据集的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 深度学习pytorch--softmax
- 下一篇: 深度学习pytorch--softmax