手写字母数据集转换为.pickle文件
生活随笔
收集整理的這篇文章主要介紹了
手写字母数据集转换为.pickle文件
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
首先是數據集,我上傳了相關的資源,https://download.csdn.net/download/fanzonghao/10566701
轉換代碼如下:
import numpy as np
import os
import matplotlib.pyplot as plt  # only used by the commented-out test() below
import matplotlib.image as mpig
import imageio  # only used by the commented-out test() below
import pickle

# Purpose: turn the image folders of notMNIST_large and notMNIST_small into
# one .pickle file per letter folder.


def load_letter(folder, min_num_images, image_size):
    """Load every image in ``folder`` into one float32 array.

    Each file is read with ``matplotlib.image.imread`` and shifted by -0.5
    (see the commented-out ``test`` below: matplotlib reads these PNGs with
    pixel values in 0~1). Files that cannot be read, or whose shape is not
    ``(image_size, image_size)``, are reported and skipped.

    Args:
        folder: directory holding the images of a single letter.
        min_num_images: minimum number of images that must load successfully.
        image_size: expected height and width of every image.

    Returns:
        numpy.ndarray of shape (num_loaded, image_size, image_size), float32.

    Raises:
        Exception: if fewer than ``min_num_images`` images could be loaded.
    """
    image_files = os.listdir(folder)
    print(folder)
    # Pre-allocate for every file; rows of skipped files are trimmed below.
    dataset = np.ndarray(shape=(len(image_files), image_size, image_size),
                         dtype=np.float32)
    num_image = 0
    for image in image_files:
        image_file = os.path.join(folder, image)
        try:
            image_data = mpig.imread(image_file) - 0.5
            # Raise ValueError rather than assert: an AssertionError escaped the
            # except clause below (aborting the whole run instead of skipping
            # the file) and the check disappeared entirely under ``python -O``.
            if image_data.shape != (image_size, image_size):
                raise ValueError(
                    'unexpected image shape {}'.format(image_data.shape))
            dataset[num_image, :, :] = image_data
            num_image += 1
        except (IOError, ValueError) as e:
            print('could not read:', image_file, e, 'skipping')
    # Too few usable samples for this letter.
    if num_image < min_num_images:
        raise Exception('samples is few {}<{}'.format(num_image, min_num_images))
    # Drop the pre-allocated rows that belonged to skipped files.
    dataset = dataset[0:num_image, :, :]
    print('full dataset tensor:', dataset.shape)
    print('Mean:', np.mean(dataset))
    print('Standard deviation:', np.std(dataset))
    return dataset


def deal_data(base_dir, min_num_images):
    """Write one .pickle file per letter sub-directory of ``base_dir``.

    For each sub-directory ``base_dir/X`` (processed in sorted order) the
    images are loaded with ``load_letter`` at 28x28 and the array is pickled
    to ``base_dir/X.pickle``. A failed save is reported and the next letter
    is processed.

    Returns:
        list of the .pickle paths, one per sub-directory, in sorted order.
    """
    # Keep directories only: the X.pickle files written by a previous run
    # live in base_dir too, and load_letter cannot list a file.
    # NOTE(review): any other sub-directory placed in base_dir (e.g. a
    # folder the pickles were moved into) would still be treated as a letter.
    data_folders = [
        os.path.join(base_dir, name)
        for name in sorted(os.listdir(base_dir))
        if os.path.isdir(os.path.join(base_dir, name))
    ]
    dataset_names = []
    for folder in data_folders:
        set_filename = folder + '.pickle'
        dataset_names.append(set_filename)
        dataset = load_letter(folder, min_num_images, image_size=28)
        try:
            with open(set_filename, 'wb') as f:
                pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)
        except Exception as e:
            # Best effort: report and continue with the next letter.
            print('unable to save data', set_filename, e)
    print(dataset_names)
    return dataset_names


def produce_train_test_pickle():
    """Pickle the training (notMNIST_large) and test (notMNIST_small) image folders."""
    train_dir = './data/notMNIST_large'
    deal_data(train_dir, min_num_images=45000)
    test_dir = './data/notMNIST_small'
    deal_data(test_dir, min_num_images=1800)


# # Test: difference between the imageio module and matplotlib.image
# def test():
#     image_path = os.path.join('./data/notMNIST_small', 'MDEtMDEtMDAudHRm.png')
#     # matplotlib.image reads the pixels in the range 0~1
#     image = mpig.imread(image_path)
#     print(image.shape)
#     print((image - 0.5) / 1)
#     plt.subplot(121)
#     plt.imshow(image)
#
#     # imageio reads the pixels in the range 0~255
#     image = imageio.imread(image_path)
#     print(image.shape)
#     print((image - 255 / 2) / 255)
#     plt.subplot(122)
#     plt.imshow(image)
#
#     plt.show()


if __name__ == '__main__':
    # test()
    produce_train_test_pickle()
    # # Output path for the results
    # output_path = './data/notMNIST_small/Pickles'
    # if not os.path.exists(output_path):
    #     os.makedirs(output_path)

# 打印生成的結果: (printed results)
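將兩個數據集中每個字母生成的.pickle文件合併成一個完整的.pickle數據集(包含訓練集、驗證集和測試集),這樣在使用的時候方便直接調用。注意:上面的代碼把每個字母的.pickle文件保存在對應字母文件夾的旁邊(例如./data/notMNIST_large/A.pickle),而下面的代碼從./data/notMNIST_large/Pickles和./data/notMNIST_small/Pickles中讀取,因此運行前需要先把這些.pickle文件移到對應的Pickles文件夾中。代碼如下: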
import numpy as np
import data_deal  # only referenced by the commented-out call below
import os
import pickle

# Purpose of this script:
#   1. (optionally) call data_deal to turn the image folders into per-letter
#      .pickle files;
#   2. combine those per-letter .pickle files into train_dataset,
#      valid_dataset and test_dataset.

# Generate the per-letter .pickle files; only needed when they do not exist yet.
# data_deal.produce_train_test_pickle()


def make_array(rows, img_size):
    """Allocate an uninitialised image array and label array.

    Returns:
        (dataset, labels): float32 array of shape (rows, img_size, img_size)
        and int32 array of shape (rows,); ``(None, None)`` when rows is 0.
    """
    if rows:
        dataset = np.ndarray(shape=(rows, img_size, img_size), dtype=np.float32)
        labels = np.ndarray(shape=(rows,), dtype=np.int32)
    else:
        dataset, labels = None, None
    return dataset, labels


def produce_train_test_datasets(pickle_files, train_size, valid_size=0):
    """Build a labelled validation set and training set from per-letter pickles.

    The i-th file of ``pickle_files`` gets label ``i``. Each file is loaded
    and shuffled along its first axis. Then its first
    ``valid_size // num_classes`` samples go to the validation set and the
    next ``train_size // num_classes`` samples go to the training set, so the
    two sets never share a sample.

    Args:
        pickle_files: paths of the per-letter .pickle files, one per class.
        train_size: requested total number of training samples.
        valid_size: requested total number of validation samples (0 = none).

    Returns:
        (valid_dataset, valid_labels, train_dataset, train_labels). A pair is
        ``(None, None)`` when its size rounds down to 0. Each size is rounded
        down to a multiple of the number of classes. Previously the extra
        rows were left as uninitialised memory.

    Raises:
        ValueError: if ``pickle_files`` is empty.
    """
    num_classes = len(pickle_files)
    if num_classes == 0:
        raise ValueError('pickle_files is empty')
    valid_size_per_class = valid_size // num_classes
    train_size_per_class = train_size // num_classes
    # np.ndarray does not initialise memory, so allocate exactly the rows
    # that the loop below fills.
    valid_dataset, valid_labels = make_array(
        valid_size_per_class * num_classes, img_size=28)
    train_dataset, train_labels = make_array(
        train_size_per_class * num_classes, img_size=28)
    # Training samples come right after the validation samples of each letter.
    train_end_in_letter = valid_size_per_class + train_size_per_class

    start_v, start_t = 0, 0
    for label, pickle_file in enumerate(pickle_files):
        # NOTE: pickle.load can execute arbitrary code; only load pickles
        # this pipeline produced.
        with open(pickle_file, 'rb') as f:
            every_letter_samples = pickle.load(f)
        # Shuffle in place along the first axis (the sample axis).
        np.random.shuffle(every_letter_samples)

        # Validation set (absent when building the test set).
        if valid_dataset is not None:
            end_v = start_v + valid_size_per_class
            valid_dataset[start_v:end_v, :, :] = \
                every_letter_samples[:valid_size_per_class, :, :]
            valid_labels[start_v:end_v] = label
            start_v = end_v

        # Training set.
        if train_dataset is not None:
            end_t = start_t + train_size_per_class
            train_dataset[start_t:end_t, :, :] = \
                every_letter_samples[valid_size_per_class:train_end_in_letter, :, :]
            train_labels[start_t:end_t] = label
            start_t = end_t
    return valid_dataset, valid_labels, train_dataset, train_labels


def random_letter(dataset, labels):
    """Shuffle samples and labels with one shared permutation.

    The sets are built letter by letter (A..J in sorted order), so this
    mixes the letters.
    """
    permutation = np.random.permutation(labels.shape[0])
    dataset = dataset[permutation, :, :]
    labels = labels[permutation]
    return dataset, labels


def notMNIST_pickle():
    """Write ./data/notMNIST.pickle holding shuffled train/valid/test sets.

    Training (200000) and validation (1000) samples come from the per-letter
    pickles in ./data/notMNIST_large/Pickles. Test samples (1000) come from
    ./data/notMNIST_small/Pickles. All six arrays are saved in a single dict.
    """
    train_size = 200000
    valid_size = 1000
    test_size = 1000

    train_dir = './data/notMNIST_large/Pickles'
    train_pickle_dir = [os.path.join(train_dir, i) for i in sorted(os.listdir(train_dir))]
    valid_dataset, valid_labels, train_dataset, train_labels = \
        produce_train_test_datasets(train_pickle_dir, train_size, valid_size)

    test_dir = './data/notMNIST_small/Pickles'
    test_pickle_dir = [os.path.join(test_dir, i) for i in sorted(os.listdir(test_dir))]
    _, _, test_dataset, test_labels = produce_train_test_datasets(test_pickle_dir, test_size)

    print('Training', train_dataset.shape, train_labels.shape)
    print('Validing', valid_dataset.shape, valid_labels.shape)
    print('Testing', test_dataset.shape, test_labels.shape)

    train_dataset, train_label = random_letter(train_dataset, train_labels)
    valid_dataset, valid_label = random_letter(valid_dataset, valid_labels)
    test_dataset, test_label = random_letter(test_dataset, test_labels)
    print('after shuffle training', train_dataset.shape, train_label.shape)
    print('after shuffle validing', valid_dataset.shape, valid_label.shape)
    print('after shuffle testing', test_dataset.shape, test_label.shape)

    all_pickle_file = os.path.join('./data', 'notMNIST.pickle')
    try:
        with open(all_pickle_file, 'wb') as f:
            save = {
                'train_dataset': train_dataset,
                'train_label': train_label,
                'valid_dataset': valid_dataset,
                'valid_label': valid_label,
                'test_dataset': test_dataset,
                'test_label': test_label,
            }
            pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
    except Exception as e:
        print('unable to save data', all_pickle_file, e)
    else:
        # Only report the size when the save succeeded; before, os.stat ran
        # even after a failed open and crashed.
        statinfo = os.stat(all_pickle_file)
        print('Compressed pickle size', statinfo.st_size)


if __name__ == '__main__':
    notMNIST_pickle()

# 讀取.pickle (reading the combined .pickle)
import tensorflow as tf  # not used by the code below
import numpy as np
import pickle
import matplotlib.pyplot as plt  # only used by the commented-out test() below


def reformat(dataset, labels, imgsize, C):
    """Flatten images and one-hot encode labels.

    Args:
        dataset: image array reshapeable to (-1, imgsize * imgsize).
        labels: integer class labels in [0, C).
        imgsize: height/width of each image.
        C: number of classes.

    Returns:
        (dataset, labels): float32 arrays of shape (samples, imgsize * imgsize)
        and one-hot (samples, C).
    """
    dataset = dataset.reshape(-1, imgsize * imgsize).astype(np.float32)
    # One-hot, way 1: pick the matching rows of the C x C identity matrix.
    labels = np.eye(C)[labels.reshape(-1)].astype(np.float32)
    # One-hot, way 2 (equivalent):
    # labels = (np.arange(C) == labels[:, None]).astype(np.float32)
    return dataset, labels


def pickle_dataset(path='./data/notMNIST.pickle', imgsize=28, C=10):
    """Load the combined notMNIST pickle and reformat all three splits.

    Args:
        path: location of the pickle written by the dataset-building script.
        imgsize: image height/width passed to ``reformat``.
        C: number of classes passed to ``reformat``.

    Returns:
        (train_dataset, train_label, valid_dataset, valid_label,
        test_dataset, test_label), each reformatted by ``reformat``.
    """
    # NOTE: pickle.load can execute arbitrary code; only load a file this
    # pipeline produced.
    with open(path, 'rb') as f:
        restore = pickle.load(f)
        train_dataset = restore['train_dataset']
        train_label = restore['train_label']
        valid_dataset = restore['valid_dataset']
        valid_label = restore['valid_label']
        test_dataset = restore['test_dataset']
        test_label = restore['test_label']
        del restore
    # print('Training:', train_dataset.shape, train_label.shape)
    # print('Validing:', valid_dataset.shape, valid_label.shape)
    # print('Testing:', test_dataset.shape, test_label.shape)
    train_dataset, train_label = reformat(train_dataset, train_label, imgsize=imgsize, C=C)
    valid_dataset, valid_label = reformat(valid_dataset, valid_label, imgsize=imgsize, C=C)
    test_dataset, test_label = reformat(test_dataset, test_label, imgsize=imgsize, C=C)
    # print('after Training:', train_dataset.shape, train_label.shape)
    # print('after Validing:', valid_dataset.shape, valid_label.shape)
    # print('after Testing:', test_dataset.shape, test_label.shape)
    return train_dataset, train_label, valid_dataset, valid_label, test_dataset, test_label


# # Check that the generated data is correct
# def test(train_dataset, train_label):
#     print(train_label[:10])
#     # plt.figure(figsize=(50, 20))
#     for i in range(10):
#         plt.subplot(5, 2, i + 1)
#         plt.imshow(train_dataset[i].reshape(28, 28))
#     plt.show()
#
# if __name__ == '__main__':
#     train_dataset, train_label, *_ = pickle_dataset()
#     test(train_dataset, train_label)

# 總結 (Summary)
以上是生活随笔為你收集整理的手写字母数据集转换为.pickle文件的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: ❤『面试知识集锦100篇』3.mysql
- 下一篇: ReactNative环境配置