TFRecord 的写入和读取(序列化和反序列化)
文章目錄
- 寫入/序列化
- 讀取/反序列化
- tf.train.Feature
- 具體實現
TFrecord 是 tensorflow 中的數據集存儲格式,它的寫入和讀取相當于序列化和反序列化的過程。
下面內容都是基于 tf2 版本來說明,兩個版本中 tfrecord 的核心并沒有改變,即序列化和反序列化。但是個人認為 tf2 里面對于讀取 tfrecords 文件建立 Dataset 直接用于訓練的支持更方便好用了,比如 batch、shuffle、repeat 等方面,所以在這點上基本摒棄了 tf1。
寫入/序列化
(1)將數據讀到內存,并轉換為 tf.train.Example 對象,每個對象由若干個 tf.train.Feature 的字典組成。
(2)將 tf.train.Example 對象序列化為字符串,寫入 TFRecord 文件。
讀取/反序列化
(1)通過 tf.data.TFRecordDataset 讀入原始的 TFRecord 文件,獲得一個 tf.data.Dataset 的數據集對象。(tf1 中是要創建一個 reader 來讀取 tfrecords 文件中的樣例)
(2)通過 Dataset.map 對數據集對象中每個序列化的 tf.train.Example 字符串執行 tf.io.parse_single_example 實現反序列化。
*:map 過程中,無法在 parse 內部進行某些處理,只能 parse 之后在 dataset 中迭代器“拿”出數據之后進行一些轉換。
tf.train.Feature
上面多次提到的 tf.train.Feature 支持 3 種數據格式,因此對于各種各樣的數據,必須處理成對應這三種格式,才能順利寫入/讀取。
3 種格式如下:
tf.train.BytesList: 字符串或原始 Byte 文件,通過 bytes_list 傳入。以圖片或者數組等類型數據為例,需要轉為字符串類型的再傳入 BytesList,后面會有例子。
tf.train.FloatList:浮點數,通過 float_list 傳入。
tf.train.Int64List:整數,通過 int64_list 傳入。
具體實現
#-*-coding:utf-8-*- import tensorflow as tf from tensorflow.python.platform import gfile import cv2# write def _int64_feature(value):return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))def _bytes_feature(value):return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))img_path = 'path/to/read/img' tfrd_path = 'path/to/save/tfrecords'image = cv2.imread(img_path) h, w, c = image.shape image_raw_data = gfile.FastGFile(img_path, 'rb').read() label = 1tfrd_writer = tf.io.TFRecordWriter(tfrd_path) feature = {'image': _bytes_feature(image_raw_data),'label': _int64_feature(label),'imgH': _int64_feature(h),'imgW': _int64_feature(w),'imgC': _int64_feature(c)} example = tf.train.Example(features=tf.train.Features(feature=feature)) tfrd_writer.write(example.SerializeToString()) tfrd_writer.close()# read raw_dataset = tf.data.TFRecordDataset(tfrd_path) feature_description = {'image': tf.io.FixedLenFeature([], tf.string),'label': tf.io.FixedLenFeature([], tf.int64),'imgH': tf.io.FixedLenFeature([], tf.int64),'imgW': tf.io.FixedLenFeature([], tf.int64),'imgC': tf.io.FixedLenFeature([], tf.int64)}def parse(record):features = tf.io.parse_single_example(record, feature_description)images = tf.io.decode_jpeg(features['image'])labels = features['label']imgH, imgW, imgC = features['imgH'], features['imgW'], features['imgC']shape = [imgH, imgW, imgC]return images, labels, shapedataset = raw_dataset.map(parse)for image, label, shape in dataset:print(label)cv2.imshow('img', image.numpy())cv2.waitKey()如果遇到需要將 numpy 寫入 tfrecord,可以先將 numpy 轉為字符串,然后寫入;讀取的時候再轉為 numpy 即可,注意 dtype 的對應。
import numpy as np gt_row_np = np.array([0, 0, 1, 0], dtype=np.uint8) gt_row_str = gt_row_np.tostring() gt_row = np.frombuffer(gt_row_str, dtype=np.uint8) print('gt_row_np type: {}, dtype: {}, value: {}'.format(type(gt_row_np), gt_row_np.dtype, gt_row_np)) print('gt_row_str type: {}, value: {}'.format(type(gt_row_str), gt_row_str)) print('gt_row type: {}, dtype: {}, value: {}'.format(type(gt_row), gt_row.dtype, gt_row))# output gt_row_np type: <class 'numpy.ndarray'>, dtype: uint8, value: [0 0 1 0] gt_row_str type: <class 'bytes'>, value: b'\x00\x00\x01\x00' gt_row type: <class 'numpy.ndarray'>, dtype: uint8, value: [0 0 1 0]總結
以上是生活随笔為你收集整理的TFRecord 的写入和读取(序列化和反序列化)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 免费开源动画制作软件推荐(新手必备)
- 下一篇: Win10删除右键多余选项菜单