pytorch的dataset用法详解
生活随笔
收集整理的這篇文章主要介紹了
pytorch的dataset用法详解
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
torch.utils.data 里面的dataset使用方法
當我們繼承了一個 Dataset類之后,我們需要重寫 len 方法,該方法提供了dataset的大小; getitem 方法, 該方法支持從 0 到 len(self)的索引
from torch.utils.data import Dataset, DataLoader import torchclass MyDataset(Dataset):"""下載數據、初始化數據,都可以在這里完成"""def __init__(self):self.x = torch.linspace(11,20,10)self.y = torch.linspace(1,10,10)self.len = len(self.x)def __getitem__(self, index):return self.x[index], self.y[index]def __len__(self):return self.len# 實例化這個類,然后我們就得到了Dataset類型的數據,記下來就將這個類傳給DataLoader,就可以了。 mydataset = MyDataset()#[return: # # (tensor(x1),tensor(y1)); # # (tensor(x2),tensor(y2)); # # ......train_loader2 = DataLoader(dataset=mydataset,batch_size=5,shuffle=False)for epoch in range(3): # 訓練所有!整套!數據 3 次for step,(batch_x,batch_y) in enumerate(train_loader2): # 每一步 loader 釋放一小批數據用來學習#return:#(tensor(x1,x2,x3,x4,x5),tensor(y1,y2,y3,y4,y5))#(tensor(x6,x7,x8,x9,x10),tensor(y6,y7,y8,y9,y10))# 假設這里就是你訓練的地方...# 打出來一些數據print('Epoch: ', epoch, '| Step:', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())torchvision.datasets的使用方法
torchvision中datasets中所有封裝的數據集都是torch.utils.data.Dataset的子類,它們都實現了__getitem__和__len__方法。因此,它們都可以用torch.utils.data.DataLoader進行數據加載。
用法1:使用官方數據集
可選數據集參考:https://www.pianshen.com/article/9695297328/
代碼:torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())root (string): 表示數據集的根目錄,其中根目錄存在CIFAR10/processed/training.pt和CIFAR10/processed/test.pt的子目錄 train (bool, optional): 如果為True,則從training.pt創建數據集,否則從test.pt創建數據集 download (bool, optional): 如果為True,則從internet下載數據集并將其放入根目錄。如果數據集已下載,則不會再次下載 transform (callable, optional): 接收PIL圖片并返回轉換后版本圖片的轉換函數 target_transform (callable, optional): 接收PIL接收目標并對其進行變換的轉換函數 import torchvision# 準備的測試數據集 from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWritertest_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)# 測試數據集中第一張圖片及target img, target = test_data[0] print(img.shape) print(target)writer = SummaryWriter("dataloader") for epoch in range(2):step = 0for data in test_loader:imgs, targets = data# print(imgs.shape)# print(targets)writer.add_images("Epoch: {}".format(epoch), imgs, step)step = step + 1writer.close()用法2:ImageFolder通用的自己數據集加載器
一個通用的數據加載器,數據集中的數據以以下方式組織
root/dog/xxx.png root/dog/xxy.png root/dog/xxz.pngroot/cat/123.png root/cat/nsdf3.png root/cat/asd932_.png torchvision.datasets.ImageFolder(root="root folder path", [transform, target_transform])ImageFolder有以下成員變量:
- self.classes - 用一個list保存 類名
- self.class_to_idx - 類名對應的 索引
- self.imgs - 保存(img-path, class) tuple的list
該方法可以結合 torch.utils.data.Subset使用 ,以根據示例索引將您的ImageFolder數據集分為訓練和測試。
orig_set = torchvision.datasets.Imagefolder('dataset/') # your dataset n = len(orig_set) # total number of examples n_test = int(0.1 * n) # take ~10% for test test_set = torch.utils.data.Subset(orig_set, range(n_test)) # take first 10% train_set = torch.utils.data.Subset(orig_set, range(n_test, n)) # take the rest總結
以上是生活随笔為你收集整理的pytorch的dataset用法详解的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 软测第二周作业WordCount
- 下一篇: http 请求头 header Refe