PyTorch-数据准备
PyTorch數據準備
簡介
沒有數據,所有的深度學習和機器學習都是無稽之談,本文通過Caltech101圖片數據集介紹PyTorch如何處理數據(包括數據的讀入、預處理、增強等操作)。
數據集構建
本文使用比較經典的Caltech101數據集,共含有101個類別,如下圖,其中BACKGROUND_Google為雜項,無法分類,使用該數據集時刪除該文件夾即可。
對數據集進行劃分,形成如下格式,劃分為訓練集、驗證集和測試集,每一種數據集中每個類別按照8:1:1進行數據劃分,具體代碼見scripts/dataset_split.py。
數據劃分完成后就要制作相關的數據集說明文件,在很多大型的數據集中經??吹竭@種文件且一般是csv格式的文件,該文件一般存放所有圖片的路徑及其標簽。生成了三個說明文件如下,圖中示例的是訓練集的說明文件。這部分的具體代碼見scripts/generate_desc.py。
PyTorch數據讀取API
上面構造了比較標準的數據集格式和主流的數據集說明文件,那么PyTorch如何識別這種格式的數據集呢?事實上,PyTorch對于數據集導入進行了封裝,其API為torch.utils.data中的Dataset類,只要繼承自該類即可自定義數據集。
需要注意的是Dataset類的最主要的重載方法為__getitem__方法,該方法需要傳入一個index的list(索引的列表),根據該列表去取出多個數據元素,每個元素是一個樣本(這里指圖片和標簽)。
下面的代碼構建了一個最簡單的Dataset,事實上訓練時需要定義很多預處理和增強方法,這里略過。
from torch.utils.data import Dataset import pandas as pd from PIL import Imageclass MyDataset(Dataset):def __init__(self, desc_file, transform=None):self.all_data = pd.read_csv(desc_file).valuesself.transform = transformdef __getitem__(self, index):img, label = self.all_data[index, 0], self.all_data[index, 1]img = Image.open(img).convert('RGB')if self.transform is not None:img = self.transform(img)return img, labeldef __len__(self):return len(self.all_data)if __name__ == '__main__':ds_train = MyDataset('../data/desc_train.csv', None)在成功構建這個Dataset之后就是將這個Dataset對象交給Dataloader,實例化的Dataloader會調用Dataset對象的__getitem__方法讀取一張圖片的數據和標簽并拼接為一個batch返回,作為模型的輸入。
下面示例講解如何在PyTorch中讀取數據集。
實驗效果如下,可以看到,按批次取出了數據和標簽。
其代碼如下。
from my_dataset import MyDataset from torch.utils.data import DataLoader from torchvision.transforms import transformsdesc_train = '../data/desc_train.csv' desc_valid = '../data/desc_valid.csv' batch_size = 16 lr_init = 0.001 epochs = 10normMean = [0.4948052, 0.48568845, 0.44682974] normStd = [0.24580306, 0.24236229, 0.2603115] train_transform = transforms.Compose([transforms.Resize(32),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),transforms.Normalize(normMean, normStd) # 按照imagenet標準 ])valid_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(normMean, normStd) ])# 構建MyDataset實例 train_data = MyDataset(desc_train, transform=train_transform) valid_data = MyDataset(desc_valid, transform=valid_transform)# 構建DataLoder train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True) valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size)for epoch in range(epochs):for step, data in enumerate(train_loader):inputs, labels = dataprint("epoch", epoch, "step", step)print("data shape", inputs.shape, labels.shape)數據增強
在實際使用中,數據進入模型之前會進行一些對應的預處理,如數據中心化、數據標準化,隨機裁減、圖片旋轉、鏡像翻轉,PyTorch預先定義了諸多數據增強方法,這些都放在torchvision的transforms模塊下。官方文檔只是羅列了相關API,沒有按照邏輯整理,其實數據增強主要有四大類。
- 裁減
- 轉動
- 圖像變換
- transform操作
下面逐一介紹。
裁減(Crop)
- 隨機裁減
- transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant')
- 根據給定的size進行隨機裁減。
- size參數為裁減后的圖像尺寸,為元組或者整型數,若為元組要符合(height, width)格式,若int型數則自動按照(size, size)進行長寬設定。
- padding參數設置需要填充多少個像素,為元組或者整形數,若為整形數則圖像上下左右填充padding個像素,若兩個數字的元組則第一個數字表示左右第二個數字表示上下,若四個數字則分別表示上下左右填充的像素個數。
- fill參數設置填充的值是什么,為整型或者元組只有padding_mode為constant時有效,當int時,各個通道填充該值,三個數的元組時RGB三通道分別填充。
- padding_mode參數設置填充模式,constant表示常量填充,edge表示按照圖片邊緣像素值填充,reflect根據反射進行填充,symmetric同前一個的效果但會重復邊上的值。
- 下圖左側為原圖,右側為隨機裁減的結果。尤其提醒,一般為了保證隨機裁減的區域合理,會先對圖片resize到一個合適的尺寸,再添加一個小的padding進行裁減,這樣可以保證裁減得到的圖片仍是圖像中主體內容。
- 中心裁減
- transforms.CenterCrop(size)
- 根據給定的size從中心進行裁減。
- size參數設置裁減后的圖片大小,為整型或者元組,若為元組要符合(height, width)格式,若int型數則自動按照(size, size)進行長寬設定。
- 下圖左側為原圖,右側為隨機裁減的結果。
- 隨機長寬比裁減
- transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(0.75, 1.333), interpolation=2)
- 隨機大小且隨機長寬比裁減圖片,最后resize到指定的大小。
- size設置輸出圖片的大小,為整型或者元組,可以是(h, w)或者(size, size)格式。
- scale設置隨機裁減的大小區間,為浮點型,(a, b)表示裁減后會在a倍到b倍大小之間。
- ratio設置隨機長寬比,浮點型。
- interpolation設置插值方法,默認為雙線性插值。
- 上下左右中心裁減
- transforms.FiveCrop(size)
- 對圖片進行上下左右中心進行裁減,得到5個結果圖片,返回一個4D的Tensor。
- size參數同上。
- 上下左右中心裁減并翻轉
- ransforms.TenCrop(size, vertical_flip=False)
- 對圖片進行上下左右中心裁減后并翻轉,共得到10張圖片,返回一個4D的Tensor。
- size參數同上。
- vertical_flip設置是否進行垂直翻轉,布爾型,默認水平翻轉。
轉動(Flip and Rotation)
- 概率水平翻轉
- transforms.RandomHorizontalFlip(p)
- 以概率p對圖片進行水平翻轉。
- p設置進行翻轉的概率,浮點型,默認0.5。
- 下圖左側為原圖,右側為隨機裁減的結果。
- 概率垂直翻轉
- transforms.RandomVerticalFlip(size)
- 以概率p對圖片進行垂直翻轉。
- 參數同上。
- 隨機旋轉
- transforms.RandomRotation(degrees, resample=None, expand=False, center=None, fill=None)
- 按照degree度數進行隨機旋轉。
- degrees設置旋轉最大角度,數值型或者元組,若單數值則默認(-degrees, degree)之間隨機旋轉,若兩個數字的元組則表示(a, b)區間進行隨機旋轉。
- resample設置重采樣方法,有下面三個選項PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC。
- expand設置是否擴展圖片以容納旋轉結果,布爾型,默認False。
- center設置是否中心旋轉,默認左上角旋轉。
- fill設置填充方法,同裁減部分的參數。
- 下圖右側為隨機旋轉的結果圖。
圖像變換
- 圖片大小調整
- transforms.Resize(size, interpolation=2)
- 按照指定size調整圖片大小。
- size設置圖片大小,(h, w)格式,若為int數,則按照size*height/width, size進行調整。
- interpolation設置插值方法,默認PIL.Image.BILINEAR。
- 圖像標準化
- transforms.Normalize(mean, std)
- 對圖像按通道進行標準化,即減均值后除以標準差,輸入圖片(h, w, c)格式。
- 轉為Tensor
- transforms.ToTensor()
- 將PIL Image或者Numpy.ndarray轉為PyTorch的Tensor類型,并歸一化至[0-1],注意是直接除以255。
- 填充
- transforms.Pad(padding, fill=0, padding_mode='constant')
- 對圖像進行填充。
- 參數同上面的RandomCrop。
- 亮度、對比度和飽和度
- transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
- 修改亮度、對比度和飽和度。
- 灰度化
- transforms.Grayscale(num_output_channels=1)
- 圖片轉為灰度圖。
- num_output_channels,輸出的灰度圖通道數,默認為1,為3時RGB三通道值相同。
- 線性變換
- transforms.LinearTransformation(transformation_matrix)
- 對矩陣做線性變換,可用于白化處理。
- 仿射變換
- transforms.RandomAffine(degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0)
- 概率灰度化
- transforms.RandomGrayscale(p)
- 以概率p進行灰度化。
- 轉換為PIL Image
- transforms.ToPILImage(mode)
- 將PyTorch的Tensor或者Numpy的ndarray轉為PIL的Image。
- mode參數設置準換后的圖片格式,mode為None時單通道,mode為3時轉為RGB,為4時轉為RGBA。
- Lambda
- transforms.Lambda(func)
- 將自定義的圖像變換函數應用到其中。
- func是一個lambda函數。
transform操作
PyTorch不只是能對圖像進行變換,還能對這些變換進行隨機選擇和組合。
- transforms.RandomChoice(transforms)
- 隨機選擇某一個變換進行作用。
- transforms.RandomApply(transforms, p=0.5)
- 以概率p進行變幻的選擇。
- transforms.RandomOrder(transforms)
- 隨機打亂變換的順序。
補充說明
這個系列的PyTorch教程我沒有按照之前TensorFlow2系列教程那樣進行由淺入深的展開,即從基礎張量運算API到模型訓練以及主流的深度模型構建這樣的思路,而是按照深度方法端到端的思路展開的,即數據準備、模型構建。損失及優化、訓練可視化這樣的思路。這是因為PyTorch的基礎運算操作類似于Numpy和TensorFlow,我在之前的TensorFlow教程中以及介紹過了。
本文涉及的項目代碼均可以在我的Github找到,歡迎star或者fork。因為篇幅限制,較為簡略,如有疏漏,歡迎指出。
總結
以上是生活随笔為你收集整理的PyTorch-数据准备的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 新的一年2020
- 下一篇: PyTorch-模型