TensorFlow: An Improved MNIST Handwritten Digit Recognition Program
The simple TensorFlow MNIST recognition program from the previous post does not scale well. For example, its forward-propagation function requires every variable to be passed in explicitly; as the network structure grows more complex and the number of parameters increases, the program becomes hard to read, and this style also produces a lot of redundant code. Another problem is that the trained model is never persisted: once the program exits, the trained model can no longer be used, so it cannot be reused. Worse, training a neural network usually takes a long time; if the training program crashes partway through and no intermediate results have been saved, a great deal of time and computation is wasted. The training process should therefore save intermediate snapshots of the model at regular intervals.
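To make this concrete before looking at the full programs, here is a minimal sketch of periodic checkpointing with tf.train.Saver (TensorFlow 1.x). The variable `w`, the empty training loop, and the directory `/tmp/mnist_ckpt` are placeholders for illustration only, not part of the refactored programs below.

```python
import os
import tensorflow as tf

# `w` stands in for the real model parameters; /tmp/mnist_ckpt is a hypothetical directory.
w = tf.Variable(tf.zeros([784, 10]), name="weights")
saver = tf.train.Saver()  # by default saves all variables in the graph

CKPT_DIR = "/tmp/mnist_ckpt"
os.makedirs(CKPT_DIR, exist_ok=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(3001):
        # ... one training step would run here ...
        if step % 1000 == 0:
            # Passing global_step appends the step count to the file name,
            # producing model.ckpt-0, model.ckpt-1000, model.ckpt-2000, ...
            saver.save(sess, os.path.join(CKPT_DIR, "model.ckpt"), global_step=step)
```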
The code below splits training and testing into two independent programs, which makes each component more flexible. Besides separating the different functional modules, this section also abstracts the forward-propagation process into a standalone library function. Because forward propagation is needed during both training and testing, wrapping it as a library function is not only more convenient but also guarantees that training and testing use exactly the same forward-propagation implementation.
The refactored solution to the MNIST problem is split into three programs. The first is mnist_inference.py, which defines the forward-propagation process and the parameters of the neural network. The second is mnist_train.py, which defines the training process. The third is mnist_eval.py, which defines the evaluation (testing) process.
The code below was exported from a Jupyter notebook.
1. mnist_inference.py
```python
# coding: utf-8
# mnist_inference.py: defines the forward-propagation process and the
# parameters of the neural network.
import tensorflow as tf

# 1. Parameters describing the network structure.
INPUT_NODE = 784    # number of input-layer nodes
OUTPUT_NODE = 10    # number of output-layer nodes
LAYER1_NODE = 500   # number of hidden-layer nodes


# 2. Obtain variables through tf.get_variable.
# These variables are created during training and restored from the saved model
# during testing. More conveniently, because variables can be renamed when a model
# is loaded, the same name can refer to the variable itself during training and to
# its moving average during testing. This function also adds the variable's
# regularization loss to the loss collection.
def get_weight_variable(shape, regularizer):
    # For the first layer the shape is [784, 500]; the weights are initialized
    # with the truncated-normal initializer tf.truncated_normal_initializer.
    weights = tf.get_variable(
        "weights", shape,
        initializer=tf.truncated_normal_initializer(stddev=0.1))
    # When a regularizer is given, add the regularization loss of this variable to
    # a collection named 'losses'. add_to_collection puts a tensor into the named
    # collection; 'losses' is a custom collection, not one of the collections
    # managed automatically by TensorFlow.
    if regularizer != None:
        tf.add_to_collection('losses', regularizer(weights))
    return weights


# 3. Define the forward-propagation process.
def inference(input_tensor, regularizer):
    # Declare the variables of the first layer and compute its output.
    # To retrieve an already-created variable with tf.get_variable, a context
    # manager must be created with tf.variable_scope. Here tf.get_variable and
    # tf.Variable make no essential difference, because this function is not called
    # more than once in the same program during training or testing; if it were,
    # reuse would have to be set to True after the first call.
    with tf.variable_scope('layer1'):
        weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
        biases = tf.get_variable(
            "biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)  # ReLU non-linearity

    # Declare the variables of the second layer and compute its output in the same way.
    with tf.variable_scope('layer2'):
        weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
        biases = tf.get_variable(
            "biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
        layer2 = tf.matmul(layer1, weights) + biases

    # Return the result of forward propagation.
    return layer2
```

This module defines the forward-propagation algorithm of the network. Whether training or testing, the function can be called directly without caring about the concrete network structure.
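As a side note on the reuse behaviour mentioned in the comments above, here is a small sketch (with made-up scope and variable names) of how tf.variable_scope controls whether tf.get_variable creates a new variable or fetches an existing one:

```python
import tensorflow as tf

with tf.variable_scope("layer1"):
    # The first call creates the variable "layer1/weights".
    v1 = tf.get_variable("weights", shape=[2, 3],
                         initializer=tf.truncated_normal_initializer(stddev=0.1))

with tf.variable_scope("layer1", reuse=True):
    # With reuse=True, get_variable fetches the existing variable instead of
    # creating a new one; creating it again without reuse would raise an error.
    v2 = tf.get_variable("weights", shape=[2, 3])

print(v1 is v2)  # True: both names refer to the same underlying variable
```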
2. mnist_train.py

```python
# coding: utf-8
# mnist_train.py: the training program, which uses the forward-propagation
# process defined in mnist_inference.py.
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load the constants and the forward-propagation function defined in mnist_inference.py.
import mnist_inference

# 1. Parameters of the network and the training process.
BATCH_SIZE = 100              # number of examples in one training batch: the smaller it is,
                              # the closer training is to stochastic gradient descent; the
                              # larger it is, the closer training is to (batch) gradient descent
LEARNING_RATE_BASE = 0.8      # base learning rate
LEARNING_RATE_DECAY = 0.99    # decay rate of the learning rate
REGULARIZATION_RATE = 0.0001  # coefficient of the regularization term (model complexity) in the loss
TRAINING_STEPS = 30000        # number of training steps
MOVING_AVERAGE_DECAY = 0.99   # decay rate of the moving averages
# Path and file name under which the model is saved.
MODEL_SAVE_PATH = "/home/lilong/desktop/ckptt/"
MODEL_NAME = "model.ckpt"


# 2. Define the training process.
def train(mnist):
    # Define the input and output placeholders (the placeholder mechanism supplies
    # input data; the values are provided only when the graph is run).
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')

    # Use L2 regularization. tf.contrib.layers.l2_regularizer returns a function that
    # computes the L2 regularization term of a given parameter.
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    # Call the forward-propagation function defined in mnist_inference.py directly.
    y = mnist_inference.inference(x, regularizer)
    global_step = tf.Variable(0, trainable=False)

    # Define the loss function, learning rate, moving-average operation and training step.
    # The ExponentialMovingAverage class is initialized with the decay rate 0.99 and the
    # variable global_step that controls the decay rate.
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    # Operation that updates the moving average of every trainable variable.
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    # Cross-entropy loss: because cross entropy is usually used together with softmax,
    # tf.nn.sparse_softmax_cross_entropy_with_logits wraps the two operations to speed up
    # the computation. The first argument is the forward-propagation result without the
    # softmax layer; the second is the class index of the correct answer for each example.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    # Average cross entropy over all examples in the current batch.
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    # Total loss = cross-entropy loss + regularization losses.
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))

    # Exponentially decaying learning rate.
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)
    # Passing global_step to minimize updates it automatically, so the learning rate is
    # updated accordingly.
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # Every time a batch of data is processed, the network parameters must be updated by
    # back-propagation and the moving average of every parameter must also be updated;
    # tf.control_dependencies groups the two operations into a single training op.
    with tf.control_dependencies([train_step, variables_averages_op]):
        train_op = tf.no_op(name='train')

    # Initialize the TensorFlow persistence class.
    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        # The model is no longer evaluated on the validation data during training;
        # validation and testing are handled by a separate program.
        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_op, loss, global_step],
                                           feed_dict={x: xs, y_: ys})
            # Save the model every 1000 steps.
            if i % 1000 == 0:
                # Print the current training status. Only the loss on the current batch is
                # reported; it gives a rough idea of how training is going. The accuracy on
                # the validation data is computed by a separate program.
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                # Save the current model. Passing global_step appends the number of
                # training steps to the file name of each saved model.
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)


# 3. Main entry point.
def main(argv=None):
    # mnist = input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True)
    mnist = input_data.read_data_sets("/home/lilong/desktop/MNIST_data/", one_hot=True)
    train(mnist)


if __name__ == '__main__':
    main()
```
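To illustrate how the custom 'losses' collection filled by get_weight_variable and the total loss computed above fit together, here is a small standalone sketch; the variables w1 and w2 and the constant data loss are made up for illustration.

```python
import tensorflow as tf

regularizer = tf.contrib.layers.l2_regularizer(0.0001)

# Each weight variable contributes its L2 penalty to the 'losses' collection.
w1 = tf.get_variable("w1", [3, 2], initializer=tf.truncated_normal_initializer(stddev=0.1))
tf.add_to_collection('losses', regularizer(w1))
w2 = tf.get_variable("w2", [2, 1], initializer=tf.truncated_normal_initializer(stddev=0.1))
tf.add_to_collection('losses', regularizer(w2))

data_loss = tf.constant(0.5)  # stands in for the cross-entropy term
# tf.add_n sums every tensor in the collection, so the total loss is the data
# loss plus the L2 penalties of all the weights registered above.
total_loss = data_loss + tf.add_n(tf.get_collection('losses'))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(total_loss))
```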
Running mnist_train.py produces the following output:

```
Extracting /home/lilong/desktop/MNIST_data/train-images-idx3-ubyte.gz
Extracting /home/lilong/desktop/MNIST_data/train-labels-idx1-ubyte.gz
Extracting /home/lilong/desktop/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting /home/lilong/desktop/MNIST_data/t10k-labels-idx1-ubyte.gz
After 1 training step(s), loss on training batch is 3.12471.
After 1001 training step(s), loss on training batch is 0.239917.
After 2001 training step(s), loss on training batch is 0.151938.
After 3001 training step(s), loss on training batch is 0.135801.
After 4001 training step(s), loss on training batch is 0.11508.
After 5001 training step(s), loss on training batch is 0.101712.
After 6001 training step(s), loss on training batch is 0.096526.
After 7001 training step(s), loss on training batch is 0.0867542.
After 8001 training step(s), loss on training batch is 0.0778042.
After 9001 training step(s), loss on training batch is 0.0693044.
After 10001 training step(s), loss on training batch is 0.0648921.
After 11001 training step(s), loss on training batch is 0.0598342.
After 12001 training step(s), loss on training batch is 0.0602573.
After 13001 training step(s), loss on training batch is 0.0580158.
After 14001 training step(s), loss on training batch is 0.0491354.
After 15001 training step(s), loss on training batch is 0.0492541.
After 16001 training step(s), loss on training batch is 0.045001.
After 17001 training step(s), loss on training batch is 0.0457389.
After 18001 training step(s), loss on training batch is 0.0468493.
After 19001 training step(s), loss on training batch is 0.0440138.
After 20001 training step(s), loss on training batch is 0.0405837.
After 21001 training step(s), loss on training batch is 0.0393501.
After 22001 training step(s), loss on training batch is 0.0451467.
After 23001 training step(s), loss on training batch is 0.0376411.
After 24001 training step(s), loss on training batch is 0.0366882.
After 25001 training step(s), loss on training batch is 0.0394025.
After 26001 training step(s), loss on training batch is 0.0351238.
After 27001 training step(s), loss on training batch is 0.0339706.
After 28001 training step(s), loss on training batch is 0.0376363.
After 29001 training step(s), loss on training batch is 0.0388179.
```
3. mnist_eval.py

```python
# coding: utf-8
# mnist_eval.py: the evaluation program.
import time

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load the constants and functions defined in mnist_inference.py and mnist_train.py.
import mnist_inference
import mnist_train

# 1. Load the latest model every 10 seconds and evaluate its accuracy on the test
# (validation) data.
EVAL_INTERVAL_SECS = 10


def evaluate(mnist):
    with tf.Graph().as_default() as g:
        # Define the format of the input and output.
        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}

        # Compute the forward-propagation result by calling the shared function.
        # The regularization loss is irrelevant during evaluation, so the
        # regularizer argument is set to None.
        y = mnist_inference.inference(x, None)

        # Compute the accuracy from the forward-propagation result. To classify new
        # examples, tf.argmax(y, 1) already gives the predicted class of each input.
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # Load the model through variable renaming, so the forward-propagation code does
        # not need to call the moving-average functions to obtain the averaged values.
        # This way the forward propagation defined in mnist_inference.py is reused as-is.
        variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        # Compute the accuracy every EVAL_INTERVAL_SECS seconds to track how it changes
        # during training.
        while True:
            with tf.Session() as sess:
                # tf.train.get_checkpoint_state finds the newest model in the directory
                # through the checkpoint file.
                ckpt = tf.train.get_checkpoint_state("/home/lilong/desktop/ckptt/")
                # ckpt.model_checkpoint_path is the path of the newest saved model; the
                # model name does not need to be given explicitly, because the checkpoint
                # file records which saved file is the latest.
                if ckpt and ckpt.model_checkpoint_path:
                    # Load the model.
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # Recover the number of training steps from the file name
                    # (plain string splitting, not a regular expression).
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
                    print("After %s training step(s), validation accuracy = %g" % (global_step, accuracy_score))
                else:
                    print('No checkpoint file found')
                    return
            time.sleep(EVAL_INTERVAL_SECS)


# 2. Main entry point.
def main(argv=None):
    mnist = input_data.read_data_sets("/home/lilong/desktop/MNIST_data/", one_hot=True)
    evaluate(mnist)


if __name__ == '__main__':
    main()
```

This evaluation code runs once every 10 seconds; each run loads the most recently saved model and computes its accuracy on the MNIST validation set.
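For clarity, recovering the step count from the checkpoint path is plain string splitting; with a hypothetical path it works like this:

```python
# Hypothetical checkpoint path as returned by tf.train.get_checkpoint_state().
model_checkpoint_path = "/home/lilong/desktop/ckptt/model.ckpt-29001"

# Take the file name after the last '/', then the suffix after the last '-'.
global_step = model_checkpoint_path.split('/')[-1].split('-')[-1]
print(global_step)  # prints: 29001
```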
If the evaluation program is run on its own after the training program has finished, it produces the following output:

```
Extracting /home/lilong/desktop/MNIST_data/train-images-idx3-ubyte.gz
Extracting /home/lilong/desktop/MNIST_data/train-labels-idx1-ubyte.gz
Extracting /home/lilong/desktop/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting /home/lilong/desktop/MNIST_data/t10k-labels-idx1-ubyte.gz
INFO:tensorflow:Restoring parameters from /home/lilong/desktop/ckptt/model.ckpt-29001
After 29001 training step(s), validation accuracy = 0.9846
INFO:tensorflow:Restoring parameters from /home/lilong/desktop/ckptt/model.ckpt-29001
After 29001 training step(s), validation accuracy = 0.9846
INFO:tensorflow:Restoring parameters from /home/lilong/desktop/ckptt/model.ckpt-29001
After 29001 training step(s), validation accuracy = 0.9846
INFO:tensorflow:Restoring parameters from /home/lilong/desktop/ckptt/model.ckpt-29001
After 29001 training step(s), validation accuracy = 0.9846
INFO:tensorflow:Restoring parameters from /home/lilong/desktop/ckptt/model.ckpt-29001
After 29001 training step(s), validation accuracy = 0.9846
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-15-c2f081a58572> in <module>()
      5
      6 if __name__ == '__main__':
----> 7     main()

<ipython-input-15-c2f081a58572> in main(argv)
      2     # mnist = input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True)
      3     mnist = input_data.read_data_sets("/home/lilong/desktop/MNIST_data/", one_hot=True)
----> 4     evaluate(mnist)
      5
      6 if __name__ == '__main__':

<ipython-input-14-2c48dfb7e249> in evaluate(mnist)
     35                 print('No checkpoint file found')
     36                 return
---> 37         time.sleep(EVAL_INTERVAL_SECS)

KeyboardInterrupt:
```

The output shows that the "latest" model is always the same one, so this is offline evaluation. To evaluate online, mnist_eval.py should be run at the same time as mnist_train.py; however, the evaluation program must not be started until the first trained model has been saved, otherwise it prints: No checkpoint file found.
Running the two programs online at the same time looks like this:

[Figure: output of the training process]

[Figure: output of the evaluation process running at the same time]
The most critical part of this example is:
```python
# Load the model through variable renaming, so the forward-propagation code does not
# need to call the moving-average functions to obtain the averaged values. This allows
# the forward propagation defined in mnist_inference.py to be reused as-is; this is
# the key point.
variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
```

This is what makes it possible to separate training from testing. For more on variable renaming, model saving, and model loading, see: https://blog.csdn.net/lilong117194/article/details/81742536
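To see why this works, here is a tiny sketch (with a made-up variable v) of what variables_to_restore() actually returns:

```python
import tensorflow as tf

v = tf.Variable(0.0, name="v")
ema = tf.train.ExponentialMovingAverage(0.99)

# variables_to_restore() builds a dictionary that maps the name of each variable's
# shadow (moving-average) variable to the variable itself, roughly
# {"v/ExponentialMovingAverage": <the variable v>}.
print(ema.variables_to_restore())

# A Saver constructed from this dictionary therefore loads the *saved moving-average
# values* into the ordinary variables, so the unmodified forward-propagation code
# automatically uses the averaged weights.
saver = tf.train.Saver(ema.variables_to_restore())
```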
Source: 《TensorFlow实战Google深度学习框架》, Section 5.5: Best-Practice Sample Program.