Loading a model and using it to train another model in TensorFlow
I have trained a model, Color_Model, in TensorFlow, and it works fine. I want to use this trained model to train another model, Motion_Model. The output of the Color_Model is fed into the Motion_Model and helps its training. The problem is that I do not know how to load the Color_Model graph and set up the Motion_Model graph so that TensorFlow knows they are separate. I changed the names of the weights in the Motion_Model so that there are no name conflicts.
Here is part of the code for loading and training:
with tf.Session() as sess:
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    ### Loading the color model
    new_saver = tf.train.import_meta_graph('./Color_Model/Deep_CNN_Color_Arch16.ckpt-44.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./Color_Model/'))
    graph = tf.get_default_graph()
    X = graph.get_tensor_by_name("X:0")
    Y = graph.get_tensor_by_name("Y:0")
    phase = graph.get_tensor_by_name("phase:0")
    A7 = graph.get_tensor_by_name("Finalo:0")
    ##########################
    ### Training phase
    for step in range(1, iterations+1):
        ### Getting the training data batch
        img = sess.run([image])
        X_temp = img[0][:,:,:,0:8]
        Y_temp = img[0][:,:,:,8:9]
        X_temp = X_temp.astype(np.float32)/255
        Y_temp = Y_temp.astype(np.float32)/255
        ### Getting the color model result
        output = sess.run([A7], feed_dict = {X: X_temp[:,:,:,5:8], Y: Y_temp, phase: False})
        ### Training the motion model
        _, c, outputM = sess.run([optimizer, costM, MN_out], feed_dict = {XM: X_temp[:,:,:,0:5], YM: Y_temp, phaseM: True, ZM: output})
As you can see, the first "sess.run" runs the Color_Model to get its output, and the second "sess.run" feeds that output to the Motion_Model to train it.
But when I run this code, I get the following error:
Traceback (most recent call last):
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1292, in _do_call
return fn(*args)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1277, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1367, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.NotFoundError: Key WM1 not found in checkpoint
[[{{node save/RestoreV2}} = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1538, in restore
{self.saver_def.filename_tensor_name: save_path})
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 887, in run
run_metadata_ptr)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1110, in _run
feed_dict_tensor, options, run_metadata)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1286, in _do_run
run_metadata)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1308, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: Key WM1 not found in checkpoint
[[{{node save/RestoreV2}} = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
Caused by op 'save/RestoreV2', defined at:
File "Detection_Model1.py", line 52, in <module>
saver = tf.train.Saver()
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1094, in __init__
self.build()
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1106, in build
self._build(self._filename, build_save=True, build_restore=True)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1143, in _build
build_save=build_save, build_restore=build_restore)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 787, in _build_internal
restore_sequentially, reshape)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 406, in _AddRestoreOps
restore_sequentially)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 854, in bulk_restore
return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1466, in restore_v2
shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
return func(*args, **kwargs)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3272, in create_op
op_def=op_def)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1768, in __init__
self._traceback = tf_stack.extract_stack()
NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint.
Please ensure that you have not altered the graph expected based on the checkpoint. Original error:
Key WM1 not found in checkpoint
[[{{node save/RestoreV2}} = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
I am fairly sure the graphs are getting mixed up: WM1 is the weight of the first layer in the Motion_Model, and the error says it cannot be found in the checkpoint, which I guess refers to the Color_Model checkpoint. I would really appreciate any help with this problem.
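One possible way to keep the two models apart, sketched here under the assumption of TensorFlow 1.x (the API used in the code above): the traceback reports that the failing op 'save/RestoreV2' was defined by the tf.train.Saver() call in Detection_Model1.py, which suggests that the restore ran a Saver built for the Motion_Model variables against the Color_Model checkpoint. Importing the Color_Model into its own tf.Graph with its own session may avoid that clash. The helper name load_color_model and the returned dictionary are hypothetical and only for illustration; the tensor names are the ones from the question.

```python
def load_color_model(meta_path, checkpoint_dir):
    """Load the trained Color_Model into its own graph and session.

    Hypothetical helper (not from the original post). Assumes TensorFlow 1.x.
    Returns the session and a dict of the tensors used in the question.
    """
    # Imported inside the function so that defining this helper
    # does not itself require TensorFlow to be installed.
    import tensorflow as tf

    color_graph = tf.Graph()
    with color_graph.as_default():
        # The imported nodes, including the checkpoint's own save/restore
        # ops, are created in color_graph only, so they stay apart from
        # the Motion_Model graph and its tf.train.Saver().
        color_saver = tf.train.import_meta_graph(meta_path)
        color_sess = tf.Session(graph=color_graph)
        color_saver.restore(color_sess,
                            tf.train.latest_checkpoint(checkpoint_dir))

    # Tensor names taken from the question's code.
    tensors = {
        "X": color_graph.get_tensor_by_name("X:0"),
        "Y": color_graph.get_tensor_by_name("Y:0"),
        "phase": color_graph.get_tensor_by_name("phase:0"),
        "A7": color_graph.get_tensor_by_name("Finalo:0"),
    }
    return color_sess, tensors
```

With this arrangement, the Color_Model output would be fetched with color_sess.run on tensors["A7"], feeding tensors["X"], tensors["Y"] and tensors["phase"], while the existing sess keeps running only the Motion_Model ops (optimizer, costM, MN_out). The separate session should be closed with color_sess.close() once training finishes.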