TensorFlow构建模型(图片数据加载)六
概要
本文內容來源于TensorFlow教程
本文主要介紹了三種圖片數據的加載和預處理方法:
內容
import numpy as np import os import PIL import PIL.Image import tensorflow as tf import tensorflow_datasets as tfds import pathlib# 下載花的數據集 dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz" data_dir = tf.keras.utils.get_file(origin=dataset_url,fname='flower_photos',untar=True) data_dir = pathlib.Path(data_dir) os.listdir(data_dir) # ['LICENSE.txt', 'tulips', 'roses', 'dandelion', 'daisy', 'sunflowers'] image_count = len(list(data_dir.glob('*/*.jpg'))) print(image_count) # 3670數據集的目錄格式:
flowers_photos/
????????? daisy/
??????????dandelion/
????????? roses/
????????? sunflowers/
????????? tulips/
使用tf.keras.utils.image_dataset_from_directory將圖片數據集加載存入內存
這里注意,我們在使用tf.keras.utils.image_dataset_from_directory加載數據的時候使用image_size參數重新定義了圖片的大小。這個步驟也可以定義在模型中,通過使用tf.keras.layers.Resizing。
大數據集的情況數據加載有可能會成為模型訓練的瓶頸,可以通過以下兩種方法使用緩存的方式加載數據:
以上是使用tf.keras.utils.image_dataset_from_directory加載數據的方法,下面使用tf.data更好的控制數據輸入,通過使用tf.data編寫自己的數據輸入通道。
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'), shuffle=False) list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)for f in list_ds.take(5):print(f.numpy())# 使用文件的樹結構生成類別組數 class_names = np.array(sorted([item.name for item in data_dir.glob('*') if item.name != "LICENSE.txt"])) print(class_names) # 劃分訓練集和驗證集 val_size = int(image_count * 0.2) train_ds = list_ds.skip(val_size) val_ds = list_ds.take(val_size)print(tf.data.experimental.cardinality(train_ds).numpy()) print(tf.data.experimental.cardinality(val_ds).numpy())# 轉換文件路徑成(img, label)對 def get_label(file_path):# Convert the path to a list of path componentsparts = tf.strings.split(file_path, os.path.sep)# The second to last is the class-directoryone_hot = parts[-2] == class_names# Integer encode the labelreturn tf.argmax(one_hot)def decode_img(img):# Convert the compressed string to a 3D uint8 tensorimg = tf.io.decode_jpeg(img, channels=3)# Resize the image to the desired sizereturn tf.image.resize(img, [img_height, img_width])def process_path(file_path):label = get_label(file_path)# Load the raw data from the file as a stringimg = tf.io.read_file(file_path)img = decode_img(img)return img, label# 使用`Dataset.map`創建一個image,label對數據集 # Set `num_parallel_calls` so multiple images are loaded/processed in parallel. train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE) val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)for image, label in train_ds.take(1):print("Image shape: ", image.numpy().shape)print("Label: ", label.numpy())為了性能配置數據集。
def configure_for_performance(ds):ds = ds.cache() # 緩存ds = ds.shuffle(buffer_size=1000) # 打亂數據ds = ds.batch(batch_size) # 批處理ds = ds.prefetch(buffer_size=AUTOTUNE) # 保證批量數據盡快可用return dstrain_ds = configure_for_performance(train_ds) val_ds = configure_for_performance(val_ds)# 可視化數據 image_batch, label_batch = next(iter(train_ds))plt.figure(figsize=(10, 10)) for i in range(9):ax = plt.subplot(3, 3, i + 1)plt.imshow(image_batch[i].numpy().astype("uint8"))label = label_batch[i]plt.title(class_names[label])plt.axis("off")使用TensorFlow數據集
(train_ds, val_ds, test_ds), metadata = tfds.load('tf_flowers',split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],with_info=True,as_supervised=True, )num_classes = metadata.features['label'].num_classes print(num_classes)get_label_name = metadata.features['label'].int2strimage, label = next(iter(train_ds)) _ = plt.imshow(image) _ = plt.title(get_label_name(label))train_ds = configure_for_performance(train_ds) val_ds = configure_for_performance(val_ds) test_ds = configure_for_performance(test_ds)為了完整性,我們構建了一個卷積網絡訓練模型。三個帶最大池化的卷積層,一個全連接層
num_classes = 5model = tf.keras.Sequential([tf.keras.layers.Rescaling(1./255),tf.keras.layers.Conv2D(32, 3, activation='relu'),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Conv2D(32, 3, activation='relu'),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Conv2D(32, 3, activation='relu'),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(num_classes) ])model.compile(optimizer='adam',loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])model.fit(train_ds,validation_data=val_ds,epochs=3 )我們也可以自己寫一個訓練循環器替代model.fit,詳情參考:從頭編寫訓練循環
總結
以上是生活随笔為你收集整理的TensorFlow构建模型(图片数据加载)六的全部內容,希望文章能夠幫你解決所遇到的問題。