[英]How can I reroute the training input pipeline to test pipeline in tensorflow using tf.contrib.graph_editor?
Suppose now I have a training input pipeline which finally generate train_x
and train_y
using tf.train.shuffle_batch
. 现在假设我有一个训练输入管道并最终产生
train_x
和train_y
使用tf.train.shuffle_batch
。 I export meta graph and re-import the graph in another code file. 我导出元图,然后将图重新导入另一个代码文件中。 Now I want to detach the input pipeline, ie, the
train_x
and train_y
, and connect a new test_x
and test_y
. 现在,我想分离输入管道,即
train_x
和train_y
,并连接新的test_x
和test_y
。 How can I make accomplish this using tf.contrib.graph_editor? 如何使用tf.contrib.graph_editor完成此操作?
EDIT: As suggested by @iga, I change my input directory using input_map
编辑:由@iga建议,我使用
input_map
更改输入目录
filenames = tf.train.match_filenames_once(FLAGS.data_dir + '*', name='matching_filenames')
if FLAGS.ckpt != '':
latest = FLAGS.log_dir + FLAGS.ckpt
else:
latest = tf.train.latest_checkpoint(FLAGS.log_dir)
if not latest or not os.path.exists(latest+'.meta'):
print("checkpoint " + latest + " does not exist")
sys.exit(1)
saver = tf.train.import_meta_graph(latest+'.meta',
input_map={'matching_filenames:0':filenames},
import_scope='import')
g = tf.get_default_graph()
but I get the following error: 但我收到以下错误:
ValueError: graph_def is invalid at node u'matching_filenames/Assign': Input tensor 'matching_filenames:0' Cannot convert a tensor of type string to an input of type string_ref.
ValueError:graph_def在节点u'matching_filenames / Assign'处无效:输入张量'matching_filenames:0'无法将字符串类型的张量转换为字符串类型的输入。 Are there any elegant way to resolve this?
有什么优雅的方法可以解决这个问题吗?
For this task, you should be able to just use input_map
argument to https://www.tensorflow.org/api_docs/python/tf/import_graph_def . 对于此任务,您应该只可以对https://www.tensorflow.org/api_docs/python/tf/import_graph_def使用
input_map
参数。 If you are using import_meta_graph , you can pass the input_map into its kwargs
and it will get passed down to import_graph_def. 如果您使用import_meta_graph ,则可以将input_map传递到其
kwargs
并将其向下传递到import_graph_def。
RESPONSE TO EDIT : I am assuming that your original graph (the one you are deserializing) had the same matching_filenames
variable. 对编辑的响应 :我假设您的原始图形(您要反序列化的图形)具有相同的
matching_filenames
变量。 Quite confusingly, the tensor name "matching_filenames:0" actually refers to the tensor going from the VariableV2
op to the Assign
op. 令人困惑的是,张量名称“ matching_filenames:0”实际上是指从
VariableV2
op到Assign
op的张量。 The type of this edge is string_ref
and you don't really want to break that edge. 该边缘的类型为
string_ref
,您实际上并不想破坏该边缘。
The output from a variable typically goes through an identity op called matching_filenames/read
. 变量的输出通常会通过一个名为
matching_filenames/read
的标识操作。 This is what you want to use as the key in your input_map. 这就是您要用作input_map中的键的东西。 For the value, you want the same tensor in your new
filenames
. 对于该值,您希望新的
filenames
具有相同的张量。 So, your call should probably look like: 因此,您的呼叫应该看起来像:
tf.train.import_meta_graph(latest+'.meta',
input_map={'matching_filenames/read': filenames.read_value()},
import_scope='import')
In general, variables are fairly complicated. 通常,变量非常复杂。 If this does not work, you can use some placeholder op and feed the names into it manually.
如果这不起作用,则可以使用一些占位符op并将名称手动输入其中。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.