55_pytorch,自定义数据集
1.55.自定義數據
1.55.1.數據傳遞機制
我們首先回顧識別手寫數字的程序:
... Dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True,) dataloader = torch.utils.data.DataLoader(dataset=Dataset, batch_size=64, shuffle=True) ... for epoch in range(EPOCH):for i, (image, label) in enumerate(dataloader):...從上面的程序,我們可以知道,在PyTorch中,數據傳遞機制是這樣的:
1.創建Dataset
2.Dataset傳遞給DataLoader
3.DataLoader迭代產生訓練數據提供給模型。
總結這個數據傳遞機制就是,Dataset負責建立索引到樣本的映射,DataLoader負責以特定的方式從數據集中迭代的產生一個個batch的樣本集合。在enumerate過程中實際上是dataloader按照其參數sampler規定的策略調用了其dataset的getitem方法(下文中將介紹該方法)。
在上面的識別手寫數字的例子中,數據集是直接下載的,但如果我們自己收集了一些數據,存在電腦文件夾里,我們該如何把這些數據變為可以在PyTorch框架下進行神經網絡訓練的數據集呢,即如何自定義數據集呢?
1.55.1.1.PyTorch中Dataset,DataLoader,Sample的關系
PyTorch中Dataset,DataLoader,Sampler的關系可以用下圖概括:
用文字表達就是:Dataloader中包含Sampler和Dataset,Sampler產生索引,Dataset拿著這個索引在數據集文件夾中找到對應的樣本(每個樣本對應一個索引,就像列表中每個元素對應一個索引),并給該樣本配置上標簽,最后返回(樣本+標簽)給調用方。
在enumerate過程中,Dataloader按照其參數BatchSampler規定的策略調用其Dataset的getitem方法batchsize次,得到一個batch,該batch中既包含樣本,也包含相應的標簽。
1.55.2.自定義數據集
torch.utils.data.Dataset 是一個表示數據集的抽象類。任何自定義的數據集都需要繼承這個類并覆寫相關方法。所謂數據集,其實就是一個負責處理索引(index)到樣本(sample)映射的一個類(class)。Pytorch提供兩種數據集: Map式數據集 Iterable式數據集。這里我們只介紹前者。
一個Map式的數據集必須要重寫getitem(self, index)、 len(self) 兩個內建方法,用來表示從索引到樣本的映射(Map)。這樣一個數據集dataset,舉個例子,當使用dataset[idx]命令時,可以在你的硬盤中讀取數據集中第idx張圖片以及其標簽(如果有的話); len(dataset)則會返回這個數據集的容量。
自定義數據集類的范式大致是這樣的:
class CustomDataset(torch.utils.data.Dataset):#需要繼承torch.utils.data.Datasetdef __init__(self):# TODO# 1. Initialize file path or list of file names.passdef __getitem__(self, index):# TODO# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).# 2. Preprocess the data (e.g. torchvision.Transform).# 3. Return a data pair (e.g. image and label).#這里需要注意的是,第一步:read one data,是一個data pointpassdef __len__(self):# You should change 0 to the total size of your dataset.return 0關于Dataset API的官網介紹https://pytorch.org/docs/stable/data.html#dataset-types:
Dataset類的使用:所有的類都應該是此類的子類(也就是說應該繼承該類)。所有的子類都要重寫(override) len(), getitem()。
?__len()__ : 此方法應該提供數據集的大小(容量)
?__getitem()__ : 此方法應該提供支持下標索引方式訪問數據集。
DataLoader類的使用如下:
根據這個方式,我們舉一個例子。
1.55.3.實例1
從kaggle官網下載dogsVScats的數據集(百度網盤下載鏈接見文末),該數據集包含test1文件夾和train文件夾,train文件夾中包含12500張貓的圖片和12500張狗的圖片,圖片的文件名中帶序號:
sampleSubmission.csv中的內容如下:
我們把其中前10000張貓的圖片和10000張狗的圖片作為訓練集,把后面的2500張貓的圖片和2500張狗的圖片作為驗證集。貓的label記為0,狗的label記為1。因為圖片大小不一,所以,我們需要對圖像進行transform。
# -*- coding: UTF-8 -*-import matplotlib.pyplot as plt import numpy as np import torch from torch.utils.data import Dataset, DataLoader from torchvision import transforms from PIL import Image""" 如果代碼執行的時候出現: OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized. OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program. That is dangerous, since it can degrade performance or cause incorrect results. The best thing to do is to ensure that only a single OpenMP runtime is linked into the process, e.g. by avoiding static linking of the OpenMP runtime in any library. As an unsafe, unsupported, undocumented workaround you can set the environment variable KMP_DUPLICATE_LIB_OK=TRUE to allow the program to continue to execute, but that may cause crashes or silently produce incorrect results. For more information, please see http://www.intel.com/software/products/support/.解決辦法是加上: import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" """ import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"image_transform = transforms.Compose([transforms.Resize(256), # 把圖片resize為256*256transforms.RandomCrop(224), # 隨機裁剪224*224transforms.RandomHorizontalFlip(), # 水平翻轉transforms.ToTensor(), # 將圖像轉為Tensortransforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 標準化 ])# 創建一個叫做DogVsCatDataset的Dataset,繼承自父類torch.utils.data.Dataset class DogVsCatDataset(Dataset):def __init__(self, root_dir, train=True, transform=None):"""Args:root_dir (string): Directory with all the images.transform (callable, optional): Optional transform to be applied on a sample."""self.root_dir = root_dirself.img_path = os.listdir(self.root_dir)if train:# 圖片數據中有類似:dog.12499.jpg的圖片共12499張。# x.split('.')[1] 就是文件名dog.12473.jpg中的序號部分,也是圖片的編號self.img_path = list(filter(lambda x: int(x.split('.')[1]) < 10000, self.img_path)) # 劃分訓練集和驗證集else:# 序號大于10000的編號self.img_path = list(filter(lambda x: int(x.split('.')[1]) >= 10000, self.img_path))self.transform = transformdef __len__(self):return len(self.img_path)def __getitem__(self, idx):image = Image.open(os.path.join(self.root_dir, self.img_path[idx]))label = 0 if self.img_path[idx].split('.')[0] == 'cat' else 1 # label, 貓為0,狗為1if self.transform:image = self.transform(image)label = torch.from_numpy(np.array([label]))return image, label# 來測試一下 if __name__ == '__main__':catanddog_dataset = DogVsCatDataset(root_dir='E:/BaiduNetdiskDownload/kaggle/train',train=False,transform=image_transform)# num_workers=4表示用4個線程讀取數據train_loader = DataLoader(catanddog_dataset, batch_size=8, shuffle=True, num_workers=4)# iter()函數把train_loader變為迭代器,然后調用迭代器的next()方法image, label = iter(train_loader).next()sample = image[0].squeeze()sample = sample.permute((1, 2, 0)).numpy()sample *= [0.229, 0.224, 0.225]sample += [0.485, 0.456, 0.406]sample = np.clip(sample, 0, 1)plt.imshow(sample)plt.show()print('Label is: {}'.format(label[0].numpy()))運行結果:
1.55.4.實例2
1.55.4.1.收集圖像樣本
以簡單的貓狗二分類為例,可以在網上下載一些貓狗圖片。創建以下目錄:
?data -----------------根目錄
?data/test -----------------測試集
?data/train -----------------訓練集
?data/val ------------------驗證集
在test/train/val之下在校分別創建2個文件夾,dog,cat
cat,dog文件夾下分別存放2類圖像:
之后寫一個簡單的python腳本,生成txt文件,用于指明每個圖像和標簽的對應關系。
格式:
/cat/1.jpg 0
/dog/1.jpg 1
…
如圖:
至此,樣本集的收集以及簡單歸類完成。
1.55.4.2.實現
使用到python package
| numpy | 矩陣操作,對圖像進行轉置 |
| skimage | 圖像處理,圖像I/O,圖像變換 |
| matplotlib | 圖像的顯示,可視化 |
| os | 一些文件查找操作 |
| torch | pytorch |
| torchvision | pytorch |
1.55.4.3.代碼
# -*- coding: UTF-8 -*-""" 本案例來自:https://www.jb51.net/article/199360.htm """import numpy as np from skimage import io from skimage import transform import matplotlib.pyplot as plt import os import torch import torchvision from torch.utils.data import Dataset, DataLoader from torchvision.transforms import transforms from torchvision.utils import make_grid""" 第一步: 定義一個子類,繼承Dataset類,重寫__len()__,__getitem()__方法。 細節: 1、數據集中一個一樣的表示:采用字典的形式sample = {'image': image, 'label': label}。 2、圖像的讀取:采用skimage.io進行讀取,讀取之后的結果為numpy.ndarray形式。 3、圖像變換:transform參數 """class MyDataset(Dataset):def __init__(self, root_dir, names_file, transform=None):self.root_dir = root_dirself.names_file = names_fileself.transform = transformself.size = 0self.names_list = []if not os.path.isfile(self.names_file):print(self.names_file + 'does not exist!')file = open(self.names_file)for f in file:self.names_list.append(f)self.size += 1def __len__(self):return self.sizedef __getitem__(self, idx):image_path = self.root_dir + self.names_list[idx].split(' ')[0]if not os.path.isfile(image_path):print(image_path + 'does not exists!')return Noneimage = io.imread(image_path) # use skitimagelabel = int(self.names_list[idx].split(' ')[1])sample = {'image': image, 'label': label}if self.transform:sample = self.transform(sample)return sample""" 第二步 實例化一個對象,并讀取和顯示數據集 """ train_dataset = MyDataset(root_dir='./data/train',names_file='./data/train/train.txt',transform=None)plt.figure() for (cnt, i) in enumerate(train_dataset):image = i['image']label = i['label']ax = plt.subplot(4, 4, cnt + 1)ax.axis('off')ax.imshow(image)ax.set_title('label {}'.format(label))plt.pause(0.001)if cnt == 15:break""" 第三步(可選optional) 對數據集進行變換:一般收集到的圖像大小尺寸,亮度等存在差異,變換的目的就是使得數據歸一化。另一方面,可 以通過變換進行數據增加data argument關于pytorch中的變換transforms,請參考該系列之前的文章。由于數據集中樣本采用字典dicts形式表示。 因此不能直接調用torchvision.transofrms中的方法。 本實驗只進行尺寸歸一化Resize, 數據類型變換ToTensor操作。Resize """# 變換Resize class Resize(object):def __init__(self, output_size: tuple):self.output_size = output_sizedef __call__(self, sample):# 圖像image = sample['image']# 使用skitimage.transform對圖像進行縮放image_new = transform.resize(image, self.output_size)return {'image': image_new, 'label': sample['label']}# ToTensor ## 變換ToTensor class ToTensor(object):def __call__(self, sample):image = sample['image']image_new = np.transpose(image, (2, 0, 1))return {'image': torch.from_numpy(image_new), 'label': sample['label']}""" 第四步:對整個數據集應用變換 細節:transformers.Compose()將不同的幾個組合起來。先進行Resize,再進行ToTensor """ # 對原始的訓練數據集進行變換 transformed_trainset = MyDataset(root_dir='./data/train',names_file='./data/train/train.txt',transform=transforms.Compose([Resize((224, 224)),ToTensor()]))""" 第五步:使用DataLoader進行包裝 為何要使用DataLoader? 1、深度學習的輸入是mini_batch形式 2、樣本加載時候可能需要隨機打亂順序,shuffle操作 3、樣本加載需要采用多線程 pytorch提供的DataLoader封裝了上述的功能,這樣使用起來更方便。 """ # 使用DataLoader可以利用多線程,batch,shuffle等 # 使用DataLoader可以利用多線程,batch,shuffle等 trainset_dataloader = DataLoader(dataset=transformed_trainset,batch_size=4,shuffle=True,num_workers=4)# 可視化 def show_images_batch(sample_batched):images_batch, labels_batch = \sample_batched['image'], sample_batched['label']grid = make_grid(images_batch)plt.imshow(grid.numpy().transpose(1, 2, 0))# sample_batch: Tensor , NxCxHxW plt.figure() for i_batch, sample_batch in enumerate(trainset_dataloader):show_images_batch(sample_batch)plt.axis('off')plt.ioff()plt.show()plt.show() """ 通過DataLoader包裝之后,樣本以min_batch形式輸出,而且進行了隨機打亂順序。至此,自定義數據集的完整流程已經實現,test, val集只需要改路徑即可。 """輸出類似:
補充:
更簡單的方法
上述繼承Dataset,重寫__len()__,__getitem()是通用的方法,過程相對繁瑣。對于簡單的分類數據集,pytorch中提供了更簡便的方式----ImageFolder。
如果每種類別的樣本放在各自的文件夾中,則可以直接使用ImageFolder。仍然以cat, dog二分類數據集為例:
文件結構:
Code
import torch from torch.utils.data import DataLoader from torchvision import transforms, datasets import matplotlib.pyplot as plt import numpy as np# https://pytorch.org/tutorials/beginner/data_loading_tutorial.html# data_transform = transforms.Compose([ # transforms.RandomResizedCrop(224), # transforms.RandomHorizontalFlip(), # transforms.ToTensor(), # transforms.Normalize(mean=[0.485, 0.456, 0.406], # std=[0.229, 0.224, 0.225]) # ])data_transform = transforms.Compose([transforms.Resize((224,224)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),])train_dataset = datasets.ImageFolder(root='./data/train',transform=data_transform) train_dataloader = DataLoader(dataset=train_dataset,batch_size=4,shuffle=True,num_workers=4)def show_batch_images(sample_batch):labels_batch = sample_batch[1]images_batch = sample_batch[0]for i in range(4):label_ = labels_batch[i].item()image_ = np.transpose(images_batch[i], (1, 2, 0))ax = plt.subplot(1, 4, i + 1)ax.imshow(image_)ax.set_title(str(label_))ax.axis('off')plt.pause(0.01)plt.figure() for i_batch, sample_batch in enumerate(train_dataloader):show_batch_images(sample_batch)plt.show()由于 train 目錄下只有2個文件夾,分別為cat, dog, 因此ImageFolder安裝順序對cat使用標簽0, dog使用標簽1。(輸出類似:)
1.55.5.參考文章
https://www.cnblogs.com/picassooo/p/12846617.html
https://www.jb51.net/article/199360.htm
總結
以上是生活随笔為你收集整理的55_pytorch,自定义数据集的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 当兵五年还能回去上大学吗?
- 下一篇: flink报错:Error: Stati