【Ubuntu-Tensorflow】TF1.0到TF1.2出现“Key LSTM/basic_lstm_cell/bias not found in checkpoin”问题
問題詳情:
Caused by op u’save/RestoreV2’, defined at:
File “demo.py”, line 25, in
result_dict = news_demo.newsAggreg({image_path})
File “/home/rszj/liutao/news_aggreg/news_demo.py”, line 32, in newsAggreg
predict = news_predict.run(images_path)
File “/home/rszj/liutao/news_aggreg/news_predict.py”, line 179, in run
saver = tf.train.Saver(restore_dict) # when you want to save model
File “/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py”, line 1139, in init
self.build()
File “/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py”, line 1170, in build
restore_sequentially=self._restore_sequentially)
File “/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py”, line 691, in build
restore_sequentially, reshape)
File “/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py”, line 407, in _AddRestoreOps
tensors = self.restore_op(filename_tensor, saveable, preferred_shard)
File “/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py”, line 247, in restore_op
[spec.tensor.dtype])[0])
File “/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_io_ops.py”, line 640, in restore_v2
dtypes=dtypes, name=name)
File “/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py”, line 767, in apply_op
op_def=op_def)
File “/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py”, line 2506, in create_op
original_op=self._default_original_op, op_def=op_def)
File “/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py”, line 1269, in init
self._traceback = _extract_stack()
NotFoundError (see above for traceback): Key LSTM/basic_lstm_cell/bias not found in checkpoint
[[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT], _device=”/job:localhost/replica:0/task:0/cpu:0”](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
[[Node: save/RestoreV2_26/_101 = _Recvclient_terminated=false, recv_device=”/job:localhost/replica:0/task:0/gpu:0”, send_device=”/job:localhost/replica:0/task:0/cpu:0”, send_device_incarnation=1, tensor_name=”edge_212_save/RestoreV2_26”, tensor_type=DT_FLOAT, _device=”/job:localhost/replica:0/task:0/gpu:0”]]
參考一:參考下面這篇博客進行解決
tensorflow1.x版本加載saver.restore目錄報錯
在 ubuntu源代碼 news_predict.py中
saver = tf.train.Saver(restore_dict) init = tf.global_variables_initializer() sess = tf.Session() sess.run(init) saver.restore(sess, r'model/model.ckpt')TF1.0正確運行結果如下
ubuntu中修改上述代碼如下:
saver = tf.train.Saver(restore_dict) init = tf.global_variables_initializer() savsess = tf.Session() sess.run(init) module_file = tf.train.latest_checkpoint('news_tf_model/model.ckpt') #ckpt路徑抽調出來 if module_file is not None: # 添加一個判斷語句,判斷ckpt的路徑文件saver.restore(sess, module_file)TF1.0和TF1.2運行結果全部是下面情況
參考一分析:
“if module_file is not None”該判斷僅僅是做了一個文件是否存在的判斷,并沒有從根本上解決LSTM的報錯問題,而代碼不執(zhí)行“ saver.restore(sess, module_file)”,就造成最后得到的結果為空了。
問題分析:
其實,仔細查看提示,會發(fā)現(xiàn),報錯的是指出“Key LSTM/basic_lstm_cell/bias not found in checkpoint”,那必然是LSTM中的bias定義出現(xiàn)了問題。所以筆者打印了saver = tf.train.Saver(restore_dict)中的“restore_dict”,發(fā)現(xiàn)TF1.0中和TF1.2中參數(shù)存在差異如下表
| lstm/basic_lstm_cell/weights | lstm/basic_lstm_cell/kernel |
| lstm/basic_lstm_cell/biases | lstm/basic_lstm_cell/bias |
原來問題確實出現(xiàn)在了LSTM上了,TF1.0和TF1.2的LSTM竟然在命名上出現(xiàn)了差異,好吧,看來要在TF1.2上使用TF1.0訓練好的ckpt模型,必須要對應LSTM的上面兩個參數(shù)了。
要對應參數(shù)其實有兩種辦法,第一種,修改ckpt模型中LSTM兩個變量名;第二種,在predict時,做符合TF版本的LSTM變量名的對應。
參考二:
接下來,先介紹第一種方法,根據(jù) 基于tensorflow 1.0的圖像敘事功能測試(model/im2txt) 這篇博客的內容,修改代碼如下
# 由于版本不同,需要進行修改 def RenameCkpt(): # 1.0.1 : 1.2.1vars_to_rename = {"lstm/basic_lstm_cell/weights": "lstm/basic_lstm_cell/kernel","lstm/basic_lstm_cell/biases": "lstm/basic_lstm_cell/bias",}new_checkpoint_vars = {}reader = tf.train.NewCheckpointReader(FLAGS.checkpoint_path)for old_name in reader.get_variable_to_shape_map():if old_name in vars_to_rename:new_name = vars_to_rename[old_name]else:new_name = old_namenew_checkpoint_vars[new_name] = tf.Variable(reader.get_tensor(old_name))init = tf.global_variables_initializer()saver = tf.train.Saver(new_checkpoint_vars)with tf.Session() as sess:sess.run(init)saver.save(sess, "/home/ndscbigdata/work/change/tf/gan/im2txt/ckpt/newmodel.ckpt-2000000")print("checkpoint file rename successful... ")上述方法是修改ckpt模型中的lstm/basic_lstm_cell/kernel 和 lstm/basic_lstm_cell/bias,修改完成后的ckpt僅僅能夠在1.2.1上正常運行,同樣因為參數(shù)名修改了變得版本不對應,而無法在1.0.1上運行。
參考三:
根據(jù)以上描述,筆者想到了方法二,按照正常邏輯,修改restore_dict中的lstm/basic_lstm_cell/kernel 和 lstm/basic_lstm_cell/bias
restore_dict = {} for i in variables[:]: # the first is global step#restore_dict[i.name.replace(':0', '')] = iif i.name.replace(':0', '')=='LSTM/basic_lstm_cell/biases':print('LSTM/basic_lstm_cell/bias========================================')restore_dict[i.name.replace('LSTM/basic_lstm_cell/biases:0','LSTM/basic_lstm_cell/bias')] = tf.get_variable('LSTM/basic_lstm_cell/bias',[2048,])elif i.name.replace(':0', '')=='LSTM/basic_lstm_cell/weights':print('LSTM/basic_lstm_cell/kernel========================================')restore_dict[i.name.replace('LSTM/basic_lstm_cell/weights:0','LSTM/basic_lstm_cell/kernel')] = tf.get_variable('LSTM/basic_lstm_cell/kernel',[1536, 2048])else:restore_dict[i.name.replace(':0', '')] = i原本以為可以兼容1.0.1和1.2.1版本了,但是出現(xiàn)一個問題,對同一張圖片分別在tf1.0.1和tf1.2.1兩個版本下進行多標簽預測,見如下兩圖
圖1—-tf1.2.1環(huán)境下運行結果(這是正確的結果)
圖2 —-tf1.0.1環(huán)境下運行結果(發(fā)現(xiàn)只能預測第二個標簽,第一個丟失了)
至于為何丟失的問題,我在做測試中,發(fā)現(xiàn),盡管修改了對應于當前tf版本的 “l(fā)stm/basic_lstm_cell/weights” 和”lstm/basic_lstm_cell/biases”,但是并沒有起到作用,這個可以通過注釋下面兩行代碼運行程序,發(fā)現(xiàn)也是上述結果
總結:
總的來說,1.0.1和1.2.1在使用saver的時候,存在著ckpt模型參數(shù)和saver初始化restore_dict中的參數(shù)的一一對應的情況,其中以LSTM中的兩個參數(shù):lstm/basic_lstm_cell/weights 和 lstm/basic_lstm_cell/biases 容易出現(xiàn)因為版本的不同,ckpt與預測代碼中自定義的restore_dict中兩個參數(shù)不匹配的情況,就會報出本錯誤。
附加:
上述問題,筆者有在github-tensorflow官方進行問題提問,成員 skye 給了筆者一個地址作為參考,地址如下:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
這個地址中給出了checkpoint_convert的詳細代碼,內涵不同版本之間不同命名的轉化問題。
總結
以上是生活随笔為你收集整理的【Ubuntu-Tensorflow】TF1.0到TF1.2出现“Key LSTM/basic_lstm_cell/bias not found in checkpoin”问题的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 一秒钟都不会锻炼,但握力却是所有人的 2
- 下一篇: 【Ubuntu-Opencv】Ubunt