TensorFlow provides the TFRecords format as a unified way to store data. In principle, TFRecords can hold data of any form; a TFRecords file stores records serialized as tf.train.Example protocol buffers. The following shows the data structure of tf.train.Example:
message Example {
    Features features = 1;
};

message Features {
    map<string, Feature> feature = 1;
};

message Feature {
    oneof kind {
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
};
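The nesting above maps directly onto the Python API: an Example wraps a Features message, which is a map from feature names to Feature messages, each holding one of the three list types. As a minimal sketch (the feature names here are just placeholders, not part of the original code), an Example can be built, serialized, and inspected like this:

import tensorflow as tf

# Build an Example with one int64 feature and one bytes feature
example = tf.train.Example(features=tf.train.Features(feature={
    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
    'raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'some bytes'])),
}))

serialized = example.SerializeToString()  # this is what gets written to a TFRecords file

# Parsing the serialized string back recovers the same protocol buffer
restored = tf.train.Example()
restored.ParseFromString(serialized)
print(restored)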
First, here is the project structure I will use in what follows (the IDE is PyCharm 2017 Community):
The code is split across three files: prepare_data.py loads the data, make_data.py writes the TFRecords files, and read_data.py reads them back.
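Based on the paths used in the code below, the assumed project layout looks roughly like this (the imgs/ directory holds the source images and tfrecords/ receives the generated files):

project/
    imgs/            # source images read by prepare_data.py
    tfrecords/       # output directory for train.tfrecords-*
    prepare_data.py
    make_data.py
    read_data.py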
1. prepare_data.py
I skip the data-augmentation step in the code below; see the separate post on TensorFlow data augmentation for details (a minimal flip-based sketch follows the code).
import os
import cv2
import numpy as np

dir = "imgs/"

def data_augmentation(data):
    """Data augmentation (left as a pass-through here).
    :param data: list of images
    :return: augmented list of images
    """
    return data

def get_img_data(file_dir):
    """Load the image data and return it as a list.
    :param file_dir: directory containing the images
    :return: list of images as numpy arrays
    """
    files = [os.path.join(file_dir, x) for x in os.listdir(file_dir)]
    raw_data = [cv2.imread(img) for img in files]
    raw_data = data_augmentation(raw_data)
    return raw_data

if __name__ == "__main__":
    get_img_data(dir)
2. make_data.py
# _*_ coding: utf-8 _*_
import tensorflow as tf
import numpy as np
from prepare_data import get_img_data

# Feature types supported by tfrecords:
# tf.train.Feature(int64_list=tf.train.Int64List(value=[int_scalar]))
# tf.train.Feature(bytes_list=tf.train.BytesList(value=[array_string_or_byte]))
# tf.train.Feature(float_list=tf.train.FloatList(value=[float_scalar]))

# Create the tfrecords files
file_nums = 2
instance_per_file = 5
dir = "imgs/"
data = get_img_data(dir)  # type(data) is list

for i in range(file_nums):
    tfrecords_filename = './tfrecords/train.tfrecords-%.5d-of-%.5d' % (i, file_nums)
    writer = tf.python_io.TFRecordWriter(tfrecords_filename)  # create the .tfrecord file
    for j in range(instance_per_file):
        # type(data[i * instance_per_file + j]) is numpy.ndarray
        img_raw = np.asarray(data[i * instance_per_file + j]).tostring()
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[j])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))
        writer.write(example.SerializeToString())
    writer.close()
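To sanity-check what was written, the files can be iterated without building a graph. This is just a verification sketch, assuming the files above were generated successfully; the file name corresponds to i = 0 under the '%.5d-of-%.5d' pattern:

import tensorflow as tf

# Count the records in the first generated file and print the parsed labels
path = './tfrecords/train.tfrecords-00000-of-00002'
count = 0
for record in tf.python_io.tf_record_iterator(path):
    example = tf.train.Example()
    example.ParseFromString(record)
    print('label:', example.features.feature['label'].int64_list.value[0])
    count += 1
print('records in file:', count)  # expected: instance_per_file, i.e. 5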
3. read_data.py
import tensorflow as tf
import numpy as np
import cv2
import matplotlib.pyplot as plt

batch_size = 2
capacity = 1000 + 3 * batch_size
train_rounds = 3
num_epochs = 30
img_h = 333
img_w = 500

# Collect all matching tfrecords files and feed them into a filename queue
tfrecord_files = tf.train.match_filenames_once('./tfrecords/train.tfrecords-*')
queue = tf.train.string_input_producer(tfrecord_files, num_epochs=num_epochs,
                                       shuffle=True, capacity=10)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(queue)

# Parse a single serialized tf.train.Example back into its features
features = tf.parse_single_example(serialized_example, features={
    'label': tf.FixedLenFeature([], tf.int64),
    'img_raw': tf.FixedLenFeature([], tf.string),
})
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [img_h, img_w, 3])
label = tf.cast(features['label'], tf.int64)

# Assemble shuffled batches from the parsed examples
to_train_batch, to_label_batch = tf.train.shuffle_batch(
    [image, label], batch_size=batch_size, capacity=capacity,
    allow_smaller_final_batch=True, num_threads=1, min_after_dequeue=1)

with tf.Session() as sess:
    sess.run(tf.group(tf.local_variables_initializer(),
                      tf.global_variables_initializer()))
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(train_rounds):
        train_batch, label_batch = sess.run([to_train_batch, to_label_batch])
        plt.subplot(121)
        plt.imshow(train_batch[0])
        plt.subplot(122)
        plt.imshow(train_batch[1])
        plt.show()
    coord.request_stop()
    coord.join(threads)
    print('finish all')
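One detail worth noting: cv2.imread returns images in BGR channel order, while plt.imshow expects RGB, so the colours in the preview will look swapped. A small tweak (not in the original post) is to reverse the channel axis before plotting:

# Assumed tweak: reverse the channel axis before plotting (BGR -> RGB)
plt.imshow(train_batch[0][..., ::-1])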
Summary
To recap: prepare_data.py loads the raw images, make_data.py serializes them into sharded TFRecords files with tf.python_io.TFRecordWriter, and read_data.py reads them back through a filename queue, parses each tf.train.Example, and assembles shuffled batches for training.