简体   繁体   English

如何使用tf.contrib.graph_editor重新路由训练输入管道以测试tensorflow中的管道?

[英]How can I reroute the training input pipeline to test pipeline in tensorflow using tf.contrib.graph_editor?

Suppose now I have a training input pipeline which finally generate train_x and train_y using tf.train.shuffle_batch . 现在假设我有一个训练输入管道并最终产生train_xtrain_y使用tf.train.shuffle_batch I export meta graph and re-import the graph in another code file. 我导出元图,然后将图重新导入另一个代码文件中。 Now I want to detach the input pipeline, ie, the train_x and train_y , and connect a new test_x and test_y . 现在,我想分离输入管道,即train_xtrain_y ,并连接新的test_xtest_y How can I make accomplish this using tf.contrib.graph_editor? 如何使用tf.contrib.graph_editor完成此操作?

EDIT: As suggested by @iga, I change my input directory using input_map 编辑:由@iga建议,我使用input_map更改输入目录

filenames = tf.train.match_filenames_once(FLAGS.data_dir + '*', name='matching_filenames')
if FLAGS.ckpt != '':
    latest = FLAGS.log_dir + FLAGS.ckpt
else:
    latest = tf.train.latest_checkpoint(FLAGS.log_dir)
if not latest or not os.path.exists(latest+'.meta'):
    print("checkpoint " + latest + " does not exist")
    sys.exit(1)
saver = tf.train.import_meta_graph(latest+'.meta', 
                                   input_map={'matching_filenames:0':filenames},
                                   import_scope='import')
g = tf.get_default_graph() 

but I get the following error: 但我收到以下错误:

ValueError: graph_def is invalid at node u'matching_filenames/Assign': Input tensor 'matching_filenames:0' Cannot convert a tensor of type string to an input of type string_ref. ValueError:graph_def在节点u'matching_filenames / Assign'处无效:输入张量'matching_filenames:0'无法将字符串类型的张量转换为字符串类型的输入。 Are there any elegant way to resolve this? 有什么优雅的方法可以解决这个问题吗?

For this task, you should be able to just use input_map argument to https://www.tensorflow.org/api_docs/python/tf/import_graph_def . 对于此任务,您应该只可以对https://www.tensorflow.org/api_docs/python/tf/import_graph_def使用input_map参数。 If you are using import_meta_graph , you can pass the input_map into its kwargs and it will get passed down to import_graph_def. 如果您使用import_meta_graph ,则可以将input_map传递到其kwargs并将其向下传递到import_graph_def。

RESPONSE TO EDIT : I am assuming that your original graph (the one you are deserializing) had the same matching_filenames variable. 对编辑的响应 :我假设您的原始图形(您要反序列化的图形)具有相同的matching_filenames变量。 Quite confusingly, the tensor name "matching_filenames:0" actually refers to the tensor going from the VariableV2 op to the Assign op. 令人困惑的是,张量名称“ matching_filenames:0”实际上是指从VariableV2 op到Assign op的张量。 The type of this edge is string_ref and you don't really want to break that edge. 该边缘的类型为string_ref ,您实际上并不想破坏该边缘。

The output from a variable typically goes through an identity op called matching_filenames/read . 变量的输出通常会通过一个名为matching_filenames/read的标识操作。 This is what you want to use as the key in your input_map. 这就是您要用作input_map中的键的东西。 For the value, you want the same tensor in your new filenames . 对于该值,您希望新的filenames具有相同的张量。 So, your call should probably look like: 因此,您的呼叫应该看起来像:

tf.train.import_meta_graph(latest+'.meta', 
                           input_map={'matching_filenames/read': filenames.read_value()},
                           import_scope='import')

In general, variables are fairly complicated. 通常,变量非常复杂。 If this does not work, you can use some placeholder op and feed the names into it manually. 如果这不起作用,则可以使用一些占位符op并将名称手动输入其中。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 使用tf.contrib.graph_editor克隆网络 - Cloning a network with tf.contrib.graph_editor 在测试中替换输入管道(不带占位符的tf.contrib.data) - Replacing input pipeline at test (tf.contrib.data without placeholders) 使用带有skflow / tf学习功能的Tensorflow输入管道 - Using a Tensorflow input pipeline with skflow/tf learn TF 2 API中的tensorflow.contrib.graph_editor? - tensorflow.contrib.graph_editor in TF 2 API? Tensorflow输入管道用于分布式培训 - Tensorflow input pipeline for distributed training 如何在tensorflow中使用tf.RandomShuffleQueue和tf.train.shuffle_batch制作输入管道? - How to make input pipeline using tf.RandomShuffleQueue and tf.train.shuffle_batch in tensorflow? tensorflow-具有多个TFRecord文件+ tf.contrib.data.sliding_window_batch()的输入管道 - tensorflow - Input pipeline with multiple TFRecord files + tf.contrib.data.sliding_window_batch() 使用 tf.import_graph_def 附加新输入管道时如何避免图形重复? - How to avoid graph duplication when using tf.import_graph_def to append a new input pipeline? Tensorflow:使用输入管道(.csv)作为字典进行培训 - Tensorflow: using an input-pipeline (.csv) as a dictionary for training 尽管使用了 tf 数据管道,但训练速度很慢 - slow training despite using tf data pipeline
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM