Loading a model and using it for training another model in TensorFlow

I have trained a model, Color_Model, in TensorFlow, and it works fine. I want to use this trained model to help train another model, Motion_Model: the output of Color_Model is fed into Motion_Model and helps its training. The problem is that I do not know how to load the Color_Model graph and set up the Motion_Model graph so that TensorFlow knows they are separate. I have already renamed the weights in Motion_Model so that there are no name conflicts.
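One common way to keep two TF 1.x models apart is to give each one its own tf.Graph and its own tf.Session, so that restoring one model can never touch the other model's variables. The sketch below assumes TF 1.x-style graph code run through tf.compat.v1. The toy layers and shapes, the weight name WC1, the file name color.ckpt and the helper functions are hypothetical stand-ins; X, Finalo, WM1, ZM and YM reuse names from the question.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # run in TF 1.x graph mode, as the question's code does


def save_toy_color_model(ckpt_dir):
    """Save a tiny hypothetical stand-in for Color_Model (input X, output Finalo)."""
    g = tf.Graph()
    with g.as_default():
        x = tf1.placeholder(tf.float32, [None, 3], name="X")
        wc1 = tf1.get_variable("WC1", initializer=tf.ones([3, 1]))  # hypothetical weight
        tf.identity(tf.matmul(x, wc1), name="Finalo")
        saver = tf1.train.Saver()
        with tf1.Session(graph=g) as sess:
            sess.run(tf1.global_variables_initializer())
            saver.save(sess, os.path.join(ckpt_dir, "color.ckpt"))


def load_color_model(ckpt_dir):
    """Import Color_Model into its own graph and restore it in its own session."""
    ckpt = tf.train.latest_checkpoint(ckpt_dir)
    color_graph = tf.Graph()
    with color_graph.as_default():
        saver = tf1.train.import_meta_graph(ckpt + ".meta")
    sess = tf1.Session(graph=color_graph)
    saver.restore(sess, ckpt)  # only restores variables that live in color_graph
    return (sess,
            color_graph.get_tensor_by_name("X:0"),
            color_graph.get_tensor_by_name("Finalo:0"))


def build_toy_motion_model():
    """A separate graph for a toy Motion_Model; ZM receives Color_Model's output."""
    motion_graph = tf.Graph()
    with motion_graph.as_default():
        zm = tf1.placeholder(tf.float32, [None, 1], name="ZM")
        ym = tf1.placeholder(tf.float32, [None, 1], name="YM")
        wm1 = tf1.get_variable("WM1", initializer=tf.zeros([1, 1]))
        cost = tf.reduce_mean(tf.square(tf.matmul(zm, wm1) - ym))
        train_op = tf1.train.GradientDescentOptimizer(0.05).minimize(cost)
        init = tf1.global_variables_initializer()
    sess = tf1.Session(graph=motion_graph)
    sess.run(init)  # WM1 is initialised here and never looked up in the Color_Model checkpoint
    return sess, zm, ym, cost, train_op


ckpt_dir = tempfile.mkdtemp()
save_toy_color_model(ckpt_dir)
color_sess, X, A7 = load_color_model(ckpt_dir)
motion_sess, ZM, YM, costM, train_op = build_toy_motion_model()

x_batch = np.ones((4, 3), np.float32)
y_batch = np.full((4, 1), 6.0, np.float32)
for step in range(50):
    # First run: Color_Model in its own session; fetching A7 itself yields an array.
    color_out = color_sess.run(A7, feed_dict={X: x_batch})
    # Second run: train Motion_Model in its own session on that output.
    _, c = motion_sess.run([train_op, costM], feed_dict={ZM: color_out, YM: y_batch})
print("final cost:", c)

color_sess.close()
motion_sess.close()
```

Because each Saver only sees the variables of its own graph, the Color_Model checkpoint is never asked for WM1. The cost of this design is that data moves between the two models as NumPy arrays on the host, which matches what the question's loop already does with two separate sess.run calls.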

Here is the part of the code that does the loading and training:

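import numpy as np
import tensorflow as tf

# init_op, image, iterations, optimizer, costM, MN_out, XM, YM, ZM and phaseM
# are defined earlier in the script (not shown here).
with tf.Session() as sess: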
  sess.run(init_op) 
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)   
  ###Loading the color model
  new_saver = tf.train.import_meta_graph('./Color_Model/Deep_CNN_Color_Arch16.ckpt-44.meta')

  new_saver.restore(sess,tf.train.latest_checkpoint('./Color_Model/'))
  graph = tf.get_default_graph()
  X = graph.get_tensor_by_name("X:0")
  Y = graph.get_tensor_by_name("Y:0")
  phase = graph.get_tensor_by_name("phase:0")
  A7 = graph.get_tensor_by_name("Finalo:0")
  ##########################
  ###Training phase
  for step in range(1, iterations+1):
     ###Getting the training data batch
     img = sess.run([image])    
     X_temp = img[0][:,:,:,0:8]
     Y_temp = img[0][:,:,:,8:9]
     X_temp = X_temp.astype(np.float32)/255
     Y_temp = Y_temp.astype(np.float32)/255
     ###Getting the color model result
     output = sess.run([A7], feed_dict = {X: X_temp[:,:,:,5:8], Y: Y_temp, phase: False})
     ###Training the motion model       
     _, c, outputM = sess.run([optimizer, costM, MN_out], feed_dict = {XM: X_temp[:,:,:,0:5], YM: Y_temp, phaseM: True, ZM: output})

As you can see, the first "sess.run" runs Color_Model to get its output, and the second "sess.run" takes this output and feeds it to Motion_Model to train it.
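Independently of the restore error, the training loop has a detail worth checking. "sess.run([A7], ...)" passes a list of fetches, so output is a one-element Python list rather than an array, and feeding it as ZM: output makes TensorFlow convert the whole list to a NumPy array, which adds a leading dimension of size 1. If ZM was declared with a 4-D shape, this would likely show up as a shape error once the restore problem is fixed. The NumPy-only sketch below reproduces the shape change; the (2, 4, 4, 1) shape is made up for illustration.

```python
import numpy as np

# Hypothetical value that A7 evaluates to: a batch of 2 single-channel 4x4 maps.
a7_value = np.zeros((2, 4, 4, 1), dtype=np.float32)

# sess.run([A7], ...) returns a list holding one array ...
output_list = [a7_value]
# ... and converting that list to an array (as happens when it is fed)
# adds a leading dimension of size 1.
fed_from_list = np.asarray(output_list)

# Fetching A7 directly (sess.run(A7, ...)) or feeding output[0] keeps the shape.
fed_from_array = np.asarray(output_list[0])

print(fed_from_list.shape)   # (1, 2, 4, 4, 1)
print(fed_from_array.shape)  # (2, 4, 4, 1)
```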

But when I run this code, I get the following error:

    Traceback (most recent call last):
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1292, in _do_call
        return fn(*args)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1277, in _run_fn
        options, feed_dict, fetch_list, target_list, run_metadata)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1367, in _call_tf_sessionrun
        run_metadata)
    tensorflow.python.framework.errors_impl.NotFoundError: Key WM1 not found in checkpoint
      [[{{node save/RestoreV2}} = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1538, in restore
        {self.saver_def.filename_tensor_name: save_path})
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 887, in run
        run_metadata_ptr)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1110, in _run
        feed_dict_tensor, options, run_metadata)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1286, in _do_run
        run_metadata)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1308, in _do_call
        raise type(e)(node_def, op, message)
    tensorflow.python.framework.errors_impl.NotFoundError: Key WM1 not found in checkpoint
      [[{{node save/RestoreV2}} = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

    Caused by op 'save/RestoreV2', defined at:
      File "Detection_Model1.py", line 52, in <module>
        saver = tf.train.Saver()
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1094, in __init__
        self.build()
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1106, in build
        self._build(self._filename, build_save=True, build_restore=True)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1143, in _build
        build_save=build_save, build_restore=build_restore)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 787, in _build_internal
        restore_sequentially, reshape)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 406, in _AddRestoreOps
        restore_sequentially)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 854, in bulk_restore
        return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1466, in restore_v2
        shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
        op_def=op_def)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
        return func(*args, **kwargs)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3272, in create_op
        op_def=op_def)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1768, in __init__
        self._traceback = tf_stack.extract_stack()

    NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

    Key WM1 not found in checkpoint
      [[{{node save/RestoreV2}} = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

I am fairly sure the graphs are getting mixed up: WM1 is the weight of the first layer of Motion_Model, and the error says it cannot be found in the checkpoint, which I assume is the Color_Model checkpoint. I would really appreciate any help with this problem.
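The traceback points in this direction: the op that failed, save/RestoreV2, was created by saver = tf.train.Saver() at line 52 of Detection_Model1.py. A Saver built without a var_list covers every variable in the graph, WM1 included, so if that Saver's restore ops are the ones being run against the Color_Model checkpoint, they will ask for keys the checkpoint does not have. The sketch below, assuming TF 1.x-style code via tf.compat.v1 and a hypothetical Color_Model weight named WC1, lists the keys actually stored in a checkpoint with tf.train.list_variables, shows a plain Saver() failing with NotFoundError on WM1, and shows a Saver restricted through var_list restoring only the variables the checkpoint contains. Putting the two models in separate graphs is another way to get the same isolation.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # TF 1.x graph mode

ckpt_dir = tempfile.mkdtemp()

# 1) A checkpoint that only contains a Color_Model weight (hypothetical name WC1).
with tf.Graph().as_default():
    wc1 = tf1.get_variable("WC1", initializer=tf.ones([2, 2]))
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        ckpt = tf1.train.Saver().save(sess, os.path.join(ckpt_dir, "color.ckpt"))

# Keys actually stored in the checkpoint.
ckpt_keys = {name for name, _ in tf.train.list_variables(ckpt)}

# 2) One graph holding variables of both models, as in the question's script.
with tf.Graph().as_default():
    wc1 = tf1.get_variable("WC1", shape=[2, 2])
    wm1 = tf1.get_variable("WM1", initializer=tf.zeros([2, 2]))
    missing = sorted(v.op.name for v in tf1.global_variables()
                     if v.op.name not in ckpt_keys)

    with tf1.Session() as sess:
        # A Saver with no var_list covers every variable, so it also looks for WM1.
        try:
            tf1.train.Saver().restore(sess, ckpt)
            failed = False
        except tf.errors.NotFoundError:
            failed = True

        # Restrict the Saver to variables present in the checkpoint,
        # and initialise the Motion_Model variable normally.
        color_vars = [v for v in tf1.global_variables() if v.op.name in ckpt_keys]
        tf1.train.Saver(var_list=color_vars).restore(sess, ckpt)
        sess.run(tf1.variables_initializer([wm1]))
        restored = sess.run(wc1)

print(missing)  # ['WM1']
print(failed)   # True
```

In the real script, the var_list would hold the Color_Model variables and a separate Saver (or no Saver at all during restore) would handle the Motion_Model variables; which variables belong to which model is an assumption you would need to check against your own graph.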

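Disclaimer: the technical posts on this site are licensed under CC BY-SA 4.0. If you repost them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.

Guangdong ICP registration No. 18138465  © 2020-2024 STACKOOM.COM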