tensorflow学习笔记————分类MNIST数据集
生活随笔
收集整理的這篇文章主要介紹了
tensorflow学习笔记————分类MNIST数据集
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
?
在使用tensorflow分類MNIST數據集中,最容易遇到的問題是下載MNIST樣本的問題。
?
一般是通過使用tensorflow內置的函數進行下載和加載,
?
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
?
但是我使用時遇到了“urllib.error.URLError: <urlopen error [Errno 99] Cannot assign requested address>”錯誤,查了一下也沒什么好的解決方案,最后就自己去手動下載了。在python文件同目錄下建立MNIST_data,進入目錄后通過wget來下載
wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
?
最后運行我們的程序
1 import tensorflow as tf 2 from tensorflow.examples.tutorials.mnist import input_data 3 4 #通過tensorflow的庫來載入訓練的樣本 5 mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 6 7 #每個批次的大小 8 batch_size = 100 9 10 #計算有多少批次 11 n_batch = mnist.train.num_examples // batch_size 12 13 #定義兩個placeholder,x是圖片樣本,y是輸出的結果 14 x = tf.placeholder(tf.float32, [None,784]) 15 y = tf.placeholder(tf.float32, [None,10]) 16 17 #創建一個簡單的神經網絡 18 W = tf.Variable(tf.zeros([784,10])) 19 b = tf.Variable(tf.zeros([10])) 20 prediction = tf.nn.softmax(tf.matmul(x,W)+b) 21 22 #二次代價函數 23 loss = tf.reduce_mean(tf.square(y - prediction)) 24 25 #使用梯度下降法 26 train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) 27 28 #初始化變量 29 init = tf.global_variables_initializer() 30 31 #結果存放在一個布爾類型列表中, tf.argmax返回一維張量中最大的值所在的位置,就是返回識別出來最可能的結果 32 correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(prediction,1)) 33 34 #求準確率,tf.case()把bool轉化為float 35 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 36 37 with tf.Session() as sess: 38 sess.run(init) 39 for epoch in range(21): 40 for batch in range(n_batch): 41 batch_xs,batch_ys = mnist.train.next_batch(batch_size) 42 sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys}) 43 44 acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}) 45 print("Iter " + str(epoch) + ", Testing Accuracy" + str(acc)) 46
?
轉載于:https://www.cnblogs.com/QKSword/p/8723677.html
總結
以上是生活随笔為你收集整理的tensorflow学习笔记————分类MNIST数据集的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 唐卡可以去汉族的佛堂开光加持吗
- 下一篇: [CQOI2014]数三角形 组合