TensorFlow(2)-训练数据载入
tensorflow 訓(xùn)練數(shù)據(jù)載入
- 1. tf.data.Dataset
- 2. dataset 創(chuàng)建數(shù)據(jù)集的方式
- 2.1 tf.data.Dataset.from_tensor_slices()
- 2.2 tf.data.TextLineDataset()
- 2.3 tf.data.FixedLengthRecordDataset()
- 2.4 tf.data.TFRecordDataset()
- 3. dateset 迭代操作iterator
- 3.1 make_one_shot_iterator()
- 3.2 make_initializable_iterator()
- 3.3 reinitializable iterator()
- 3.4 feedable iterator()
- 4. dataset的map、batch、shuffle、repeat操作
- 5. 非eager/eager 模式
- 5.1 非eager模式demo
- 5.2 eager模式demo
1. tf.data.Dataset
參考Google官方給出的Dataset API中的類圖,Dataset 務(wù)于數(shù)據(jù)讀取,構(gòu)建輸入數(shù)據(jù)的pipeline。
Dataset可以看作是相同類型“元素”的有序列表,可使用Iterator迭代獲取Dataset中的元素。
2. dataset 創(chuàng)建數(shù)據(jù)集的方式
2.1 tf.data.Dataset.from_tensor_slices()
從tensor中創(chuàng)建數(shù)據(jù)集,數(shù)據(jù)集元素以tensor第一維度為劃分。
import tensorflow as tf import numpy as np # 切分傳入Tensor的第一個(gè)維度,生成相應(yīng)的dataset。 dataset1 = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) # 如果傳入字典,那切分結(jié)果就是字典按值切分,元素型如{"a":[1],"b":[x,x]} dataset2 = tf.data.Dataset.from_tensor_slices({"a": np.array([1.0, 2.0, 3.0, 4.0, 5.0]), "b": np.random.uniform(size=(5, 2))} )2.2 tf.data.TextLineDataset()
讀取文件數(shù)據(jù)創(chuàng)建數(shù)據(jù)集,數(shù)據(jù)集元素為文件的每一行
2.3 tf.data.FixedLengthRecordDataset()
從一個(gè)文件列表和record_bytes中創(chuàng)建數(shù)據(jù)集,數(shù)據(jù)集元素是文件中固定字節(jié)數(shù)record_bytes的內(nèi)容。
2.4 tf.data.TFRecordDataset()
讀TFRecord文件創(chuàng)建數(shù)據(jù)集,數(shù)據(jù)集中的一條數(shù)據(jù)是一個(gè)TFExample。
dataset = tf.data.TFRecordDataset(filenames = [tfrecord_file_name]) # [tfrecord_file_name] tfrecord 文件列表
frecord 文件中的特征一般都經(jīng)過tf.train.Example 序列化,在使用前需要先解碼tf.train.Example.FromString()
raw_example = next(iter(dataset)) parsed = tf.train.Example.FromString(raw_example.numpy())3. dateset 迭代操作iterator
iterator是從Dataset對(duì)象中創(chuàng)建出來的,用于迭代取數(shù)據(jù)集中的元素。
3.1 make_one_shot_iterator()
dataset.make_one_shot_iterator()–只能從頭到尾讀取一次dataset。如果一個(gè)dataset中元素被讀取完了再sess.run()的話,會(huì)拋出tf.errors.OutOfRangeError異常。因此可以在外界捕捉這個(gè)異常以判斷數(shù)據(jù)是否讀取完。
import tensorflow as tf import numpy as np # 切分傳入Tensor的第一個(gè)維度,生成相應(yīng)的dataset。如果傳入字典,那切分結(jié)果就是字典按值切分 dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) iterator = dataset.make_one_shot_iterator() # 只能從頭到尾讀取一次 one_element = iterator.get_next() # 從iterator里取出一個(gè)元素。 # 處于非Eager模式,所以one_element只是一個(gè)Tensor,并不是一個(gè)實(shí)際的值。調(diào)用sess.run(one_element)后,才能真正地取出一個(gè)值。 with tf.Session() as sess:try:while True:print(sess.run(one_element))except tf.errors.OutOfRangeError:print("end!")3.2 make_initializable_iterator()
dataset.make_initializable_iterator()–支持placeholder dataset 的迭代操作,這可以方便通過參數(shù)快速定義新的Iterator。
# limit相當(dāng)于一個(gè)參數(shù),它規(guī)定了Dataset中數(shù)的上限, 使用make_initializable_iterator limit = tf.placeholder(dtype=tf.int32, shape=[]) dataset = tf.data.Dataset.from_tensor_slices(tf.range(start=0, limit=limit)) iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() with tf.Session() as sess:sess.run(iterator.initializer, feed_dict={limit: 10})for i in range(10):value = sess.run(next_element)assert i == valuesess.run(next_element) 每run一次, 數(shù)據(jù)迭代器指針就會(huì)往下移動(dòng)一個(gè)。TF官網(wǎng)學(xué)習(xí)(9)–使用iterator注意事項(xiàng)
如果在dataset的構(gòu)建時(shí),一次性讀入了所有的數(shù)據(jù),會(huì)導(dǎo)致計(jì)算圖變得很大,給傳輸、保存帶來不便。make_initializable_iterator()支持placeholder 操作,僅在需要傳輸數(shù)據(jù)時(shí)再取數(shù)據(jù)。
# 從硬盤中讀入兩個(gè)Numpy數(shù)組 with np.load("/var/data/training_data.npy") as data:features = data["features"]labels = data["labels"]features_placeholder = tf.placeholder(features.dtype, features.shape) labels_placeholder = tf.placeholder(labels.dtype, labels.shape)dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder)) iterator = dataset.make_initializable_iterator() sess.run(iterator.initializer, feed_dict={features_placeholder: features,labels_placeholder: labels})3.3 reinitializable iterator()
dataset.reinitializable iterator() --待補(bǔ)
3.4 feedable iterator()
dataset.feedable iterator()–待補(bǔ)
4. dataset的map、batch、shuffle、repeat操作
map–接收一個(gè)函數(shù),Dataset中的每個(gè)元素都會(huì)被當(dāng)作這個(gè)函數(shù)的輸入,并將函數(shù)返回值作為新的Dataset。
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0batch–將多個(gè)元素組合成一個(gè)batch
dataset = dataset.batch(16) # 將數(shù)據(jù)集劃分為batch size為16的小批次shuffle– 打亂dataset中的元素,參數(shù)buffersize。打亂的實(shí)現(xiàn)機(jī)理:從buffer_size 大小的部buffer中隨機(jī)抽取元素,組成打亂后的數(shù)據(jù)集。buffer中被抽走的元素由原數(shù)據(jù)集中的后續(xù)元素補(bǔ)位置。 重復(fù)‘抽取-補(bǔ)充’這個(gè)過程,直至buffer為空。
會(huì)在batch之間打亂數(shù)據(jù)–疑問多tfrecord 文件是一次性構(gòu)建數(shù)據(jù)集還是一條一條的構(gòu)建
buffer_size 的大小詳見tf.data.Dataset.shuffle(buffer_size)中buffer_size的理解
dataset = dataset.shuffle(buffer_size=10000)repeat– 將整個(gè)序列重復(fù)多次,用來處理機(jī)器學(xué)習(xí)中的epoch,假設(shè)原始數(shù)據(jù)是一個(gè)epoch,使用repeat(5)就可以將之變成5個(gè)epoch
dataset = dataset.repeat(5)5. 非eager/eager 模式
5.1 非eager模式demo
在非Eager模式下,Dataset中讀出的一個(gè)元素一般對(duì)應(yīng)一個(gè)batch的Tensor,我們可以使用這個(gè)Tensor在計(jì)算圖中構(gòu)建模型。
import tensorflow as tf import numpy as np # 切分傳入Tensor的第一個(gè)維度,生成相應(yīng)的dataset。如果傳入字典,那切分結(jié)果就是字典按值切分 dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) iterator = dataset.make_one_shot_iterator() # 只能從頭到尾讀取一次 one_element = iterator.get_next() # 從iterator里取出一個(gè)元素。 # 處于非Eager模式,所以one_element只是一個(gè)Tensor,并不是一個(gè)實(shí)際的值。調(diào)用sess.run(one_element)后,才能真正地取出一個(gè)值。 with tf.Session() as sess:try:while True:print(sess.run(one_element))except tf.errors.OutOfRangeError:print("end!")5.2 eager模式demo
在Eager模式下,Dataset建立Iterator的方式有所不同,此時(shí)通過讀出的數(shù)據(jù)就是含有值的Tensor,方便調(diào)試。
import tensorflow.contrib.eager as tfe tfe.enable_eager_execution() dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) for one_element in tfe.Iterator(dataset):print(one_element) # 可直接讀取數(shù)據(jù)參考文獻(xiàn):TensorFlow全新的數(shù)據(jù)讀取方式:Dataset API入門教程
總結(jié)
以上是生活随笔為你收集整理的TensorFlow(2)-训练数据载入的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: ThinkPHP redirect 页面
- 下一篇: 在GCC和Visual Studio中使