日韩性视频-久久久蜜桃-www中文字幕-在线中文字幕av-亚洲欧美一区二区三区四区-撸久久-香蕉视频一区-久久无码精品丰满人妻-国产高潮av-激情福利社-日韩av网址大全-国产精品久久999-日本五十路在线-性欧美在线-久久99精品波多结衣一区-男女午夜免费视频-黑人极品ⅴideos精品欧美棵-人人妻人人澡人人爽精品欧美一区-日韩一区在线看-欧美a级在线免费观看

歡迎訪問 生活随笔!

生活随笔

當前位置: 首頁 > 编程资源 > 编程问答 >内容正文

编程问答

预训练模型参数重载必备!

發布時間:2025/4/16 编程问答 18 豆豆
生活随笔 收集整理的這篇文章主要介紹了 预训练模型参数重载必备! 小編覺得挺不錯的,現在分享給大家,幫大家做個參考.

版權聲明:本文為博主原創文章,遵循 CC 4.0 BY-SA 版權協議,轉載請附上原文出處鏈接和本聲明。
本文鏈接:https://blog.csdn.net/ying86615791/article/details/76215363

之前已經寫了一篇《Tensorflow保存模型,恢復模型,使用訓練好的模型進行預測和提取中間輸出(特征)》,里面主要講恢復模型然后使用該模型

假如要保存或者恢復指定tensor,并且把保存的graph恢復(插入)到當前的graph中呢?

總的來說,目前我會的是兩種方法,命名都是很關鍵!
兩種方式保存模型,
1.保存所有tensor,即整張圖的所有變量,
2.只保存指定scope的變量
兩種方式恢復模型,
1.導入模型的graph,用該graph的saver來restore變量
2.在新的代碼段中寫好同樣的模型(變量名稱及scope的name要對應),用默認的graph的saver來restore指定scope的變量


兩種保存方式:
1.保存整張圖,所有變量

...init = tf.global_variables_initializer()saver = tf.train.Saver()config = tf.ConfigProto()config.gpu_options.allow_growth=Truewith tf.Session(config=config) as sess:sess.run(init)...writer.add_graph(sess.graph)...saved_path = saver.save(sess,saver_path)...

?2.保存圖中的部分變量
???

...init = tf.global_variables_initializer()vgg_ref_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg_feat_fc')#獲取指定scope的tensorsaver = tf.train.Saver(vgg_ref_vars)#初始化saver時,傳入一個var_list的參數config = tf.ConfigProto()config.gpu_options.allow_growth=Truewith tf.Session(config=config) as sess:sess.run(init)...writer.add_graph(sess.graph)...saved_path = saver.save(sess,saver_path)...


兩種恢復方式:
1.導入graph來恢復

??

...vgg_meta_path = params['vgg_meta_path'] # 后綴是'.ckpt.meta'的文件vgg_graph_weight = params['vgg_graph_weight'] # 后綴是'.ckpt'的文件,里面是各個tensor的值saver_vgg = tf.train.import_meta_graph(vgg_meta_path) # 導入graph到當前的默認graph中,返回導入graph的saverx_vgg_feat = tf.get_collection('inputs_vgg')[0] #placeholder, [None, 4096],獲取輸入的placeholderfeat_decode = tf.get_collection('feat_encode')[0] #[None, 1024],獲取要使用的tensor"""以上兩個獲取tensor的方式也可以為:graph = tf.get_default_graph()centers = graph.get_tensor_by_name('loss/intra/center_loss/centers:0')當然,前提是有tensor的名字"""...init = tf.global_variables_initializer()saver = tf.train.Saver() # 這個是當前新圖的saverconfig = tf.ConfigProto()config.gpu_options.allow_growth=Truewith tf.Session(config=config) as sess:sess.run(init)...saver_vgg.restore(sess, vgg_graph_weight)#使用導入圖的saver來恢復...


2.重寫一樣的graph,然后恢復指定scope的變量

???

def re_build():#重建保存的那個graphwith tf.variable_scope('vgg_feat_fc'): #沒錯,這個scope要和需要恢復模型中的scope對應...return ......


??? vgg_ref_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg_feat_fc') # 既然有這個scope,其實第1個方法中,導入graph后,可以不用返回的vgg_saver,再新建一個指定var_list的vgg_saver就好了,恩,需要傳入一個var_list的參數
???

...init = tf.global_variables_initializer()saver_vgg = tf.train.Saver(vgg_ref_vars) # 這個是要恢復部分的saversaver = tf.train.Saver() # 這個是當前新圖的saverconfig = tf.ConfigProto()config.gpu_options.allow_growth=Truewith tf.Session(config=config) as sess:sess.run(init)...saver_vgg.restore(sess, vgg_graph_weight)#使用導入圖的saver來恢復...


總結一下,這里的要點就是,在restore的時候,saver要和模型對應,如果直接用當前graph的saver = tf.train.Saver(),來恢復保存模型的權重saver.restore(vgg_graph_weight),就會報錯,提示key/tensor ... not found之類的錯誤;
寫graph的時候,一定要注意寫好scope和tensor的name,合理插入variable_scope;

最方便的方式還是,用第1種方式來保存模型,這樣就不用重寫代碼段了,然后第1種方式恢復,不過為了穩妥,最好還是通過獲取var_list,指定saver的var_list,妥妥的!


最新發現,用第1種方式恢復時,要記得當前的graph和保存的模型中沒有重名的tensor,否則當前graph的tensor name可能不是那個name,可能在后面加了"_1"....-_-||


在恢復圖基礎上構建新的網絡(變量)并訓練(finetuning)(2017.11.9更新)

恢復模型graph和weights在上面已經說了,這里的關鍵點是怎樣只恢復原圖的權重 ,并且使optimizer只更新新構造變量(指定層、變量)。

(以下code與上面沒聯系)

?????

"""1.Get input, output , saver and graph"""#從導入圖中獲取需要的東西meta_path_restore = model_dir + '/model_'+model_version+'.ckpt.meta'model_path_restore = model_dir + '/model_'+model_version+'.ckpt'saver_restore = tf.train.import_meta_graph(meta_path_restore) #獲取導入圖的saver,便于后面恢復graph_restore = tf.get_default_graph() #此時默認圖就是導入的圖#從導入圖中獲取需要的tensor#1. 用collection來獲取input_x = tf.get_collection('inputs')[0]input_is_training = tf.get_collection('is_training')[0]output_feat_fused = tf.get_collection('feat_fused')[0]#2. 用tensor的name來獲取input_y = graph_restore.get_tensor_by_name('label_exp:0')print('Get tensors...')print('inputs shape: {}'.format(input_x.get_shape().as_list()))print('input_is_training shape: {}'.format(input_is_training.get_shape().as_list()))print('output_feat_fused shape: {}'.format(output_feat_fused.get_shape().as_list()))"""2.Build new variable for fine tuning"""#構造新的variables用于后面的finetuninggraph_restore.clear_collection('feat_fused') #刪除以前的集合,假如finetuning后用新的代替原來的graph_restore.clear_collection('prob')#添加新的東西if F_scale is not None and F_scale!=0:print('F_scale is not None, value={}'.format(F_scale))feat_fused = Net_normlize_scale(output_feat_fused, F_scale)tf.add_to_collection('feat_fused',feat_fused)#重新添加到新集合logits_fused = last_logits(feat_fused,input_is_training,7) # scope name是"final_logits""""3.Get acc and loss"""#構造損失with tf.variable_scope('accuracy'):accuracy,prediction = ...with tf.variable_scope('loss'):loss = ..."""4.Build op for fine tuning"""global_step = tf.Variable(0, trainable=False,name='global_step')learning_rate = tf.train.exponential_decay(initial_lr,global_step=global_step,decay_steps=decay_steps,staircase=True,decay_rate=0.1)update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)with tf.control_dependencies(update_ops):var_list = tf.contrib.framework.get_variables('final_logits')#關鍵!獲取指定scope下的變量train_op = tf.train.MomentumOptimizer(learning_rate=learning_rate,momentum=0.9).minimize(loss,global_step=global_step,var_list=var_list) #只更新指定的variables"""5.Begin training"""init = tf.global_variables_initializer()saver = tf.train.Saver()config = tf.ConfigProto()config.gpu_options.allow_growth=Truewith tf.Session(config=config) as sess:sess.run(init)saver_restore.restore(sess, model_path_restore) #這里saver_restore對應導入圖的saver, 如果用上面新的saver的話會報錯 因為多出了一些新的變量 在保存的模型中是沒有那些權值的sess.run(train_op, feed_dict).......

?

再說明下兩個關鍵點:

1. 如何在新圖的基礎上 只恢復 導入圖的權重 ?

用導入圖的saver: saver_restore

2. 如何只更新指定參數?

用var_list = tf.contrib.framework.get_variables(scope_name)獲取指定scope_name下的變量,

然后optimizer.minimize()時傳入指定var_list

?

附:如何知道tensor名字以及獲取指定變量?

1.獲取某個操作之后的輸出

用graph.get_operations()獲取所有op

比如<tf.Operation 'common_conv_xxx_net/common_conv_net/flatten/Reshape' type=Reshape>,

那么output_pool_flatten = graph_restore.get_tensor_by_name('common_conv_xxx_net/common_conv_net/flatten/Reshape:0')就是那個位置經過flatten后的輸出了

2.獲取指定的var的值

用GraphKeys獲取變量

tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)返回指定集合的變量

比如 <tf.Variable 'common_conv_xxx_net/final_logits/logits/biases:0' shape=(7,) dtype=float32_ref>

那么var_logits_biases = graph_restore.get_tensor_by_name('common_conv_xxx_net/final_logits/logits/biases:0')就是那個位置的biases了

3.獲取指定scope的collection

tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,scope='common_conv_xxx_net.final_logits')

注意后面的scope是xxx.xxx不是xxx/xxx


————————————————
版權聲明:本文為CSDN博主「美利堅節度使」的原創文章,遵循 CC 4.0 BY-SA 版權協議,轉載請附上原文出處鏈接及本聲明。
原文鏈接:https://blog.csdn.net/ying86615791/article/details/76215363

總結

以上是生活随笔為你收集整理的预训练模型参数重载必备!的全部內容,希望文章能夠幫你解決所遇到的問題。

如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。