简单的线性模型实现tensorflow权重的生成和调用,并且用类的方式实现参数共享
生活随笔
收集整理的這篇文章主要介紹了
简单的线性模型实现tensorflow权重的生成和调用,并且用类的方式实现参数共享
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
首先看文件路徑,line_regression是總文件夾,model文件夾存放權重文件,
global_variable.py寫了一句話.
?
save_path='./model/weight'權重要存放的路徑,以weight命名.
lineRegulation_model.py代碼
?
import tensorflow as tf """ 類定義一些公共量,方便模型載入用 """ class LineRegModel:def __init__(self):self.a_val=tf.Variable(tf.random_normal(shape=[1]))self.b_val = tf.Variable(tf.random_normal(shape=[1]))self.x_input=tf.placeholder(dtype=tf.float32)self.y_label = tf.placeholder(dtype=tf.float32)self.y_output = tf.multiply(self.x_input,self.a_val)+self.b_valself.loss=tf.reduce_mean(tf.pow(self.y_output-self.y_label,2))def get_op(self):return tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)定義了一個類,方便后面共享權值恢復模型的調用
model_train.py代碼:
?
import tensorflow as tf import numpy as np from save_and_restore import global_variable from save_and_restore import lineRegulation_model as model """ 訓練模型 """ train_x=np.random.rand(5) train_y=train_x*5+3 model=model.LineRegModel()#類要加括號 a_val=model.a_val b_val=model.b_val x_input=model.x_input y_label=model.y_label y_output=model.y_output loss=model.loss optimizer=model.get_op() if __name__ == '__main__':saver = tf.train.Saver()init=tf.global_variables_initializer()with tf.Session() as sess:sess.run(init)flag=Trueepoch=0while flag:epoch+=1cost,_=sess.run([loss,optimizer],feed_dict={x_input:train_x,y_label:train_y})if cost<1e-6:flag=Falseprint('a={},b={}'.format(a_val.eval(sess),b_val.eval(sess)))print('epoch={}'.format(epoch))saver.save(sess,global_variable.save_path)print('model save finish')訓練模型,并且存放模型的目的,這樣前面三段代碼就可以實現簡單的線性模型權重的生成和存放。
其中checkpoint指的是檢查點文件,記錄存儲文件名稱,weight.data_00000-of-00001權重存儲文件,weight.index存儲權重目錄
weight.meta模型的全部圖文件,所以weight.data_00000-of-00001和weight.meta是最大的。
model_restore.py代碼如下:
import tensorflow as tf from save_and_restore import global_variable,lineRegulation_model as model """ 加載模型 """ model=model.LineRegModel() x_input=model.x_input y_output=model.y_output init=tf.global_variables_initializer() saver=tf.train.Saver() with tf.Session() as sess:sess.run(init)saver.restore(sess,global_variable.save_path)result=sess.run(y_output,feed_dict={x_input:[1]})print(result)調用生成的模型打印出預測結果:
結果和8差不多。
?
總結
以上是生活随笔為你收集整理的简单的线性模型实现tensorflow权重的生成和调用,并且用类的方式实现参数共享的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 百度地图之添加覆盖物
- 下一篇: SVM原理与实战