繁体   English   中英

在张量流中导入图形时使用新的op

[英]Using new op while importing graph in tensorflow

我是TensorFlow的新手。 我正在尝试使用检查点文件导入训练有素的TensorFlow网络。 我正在使用的网络有一个自定义操作,当我在Python中使用它时可以正常工作。 但是,我必须冻结图形,因为我必须使用C ++ API。 我正在使用TensorFlow基目录中的以下命令调用freeze_graph

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=<local path>/data/graph_vgg.pb --input_checkpoint=<local path>/data/VGGnet_fast_rcnn_iter_70000.ckpt --output_node_names="cls_prob,bbox_pred" --output_graph=<local path>/graph_frozen.pb

但是,当我试图冻结图形时,我收到以下错误。

Traceback (most recent call last):
  File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 202, in <module>
    app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 134, in main
    FLAGS.variable_names_blacklist)
  File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 99, in freeze_graph
    _ = importer.import_graph_def(input_graph_def, name="")
  File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/framework/importer.py", line 260, in import_graph_def
    raise ValueError('No op named %s in defined operations.' % node.op)
ValueError: No op named RoiPool in defined operations.

输入图有一个类型为RoiPool的op的节点,TensorFlow无法识别。 我调查了抛出此错误的代码,它看起来像是未在TensorFlow中注册的操作。 我和我一起构建了.so文件。 我应该把它复制到某个地方吗? 我在网上找不到那样的东西。 任何帮助或指针都会很棒。 我花了很多时间来解决这个问题。 代码在python中工作正常,使用op的层在项目目录中。 请帮助我理解我需要做些什么来使它工作。

编辑:这是网络中使用的自定义操作代码

我不熟悉那个特定的RoiPooling实现,但我通常设置一个需要冻结的自定义op的方式是roi_pooling_op.cc和相关的python文件(定义渐变并导入* .so)都位于// tensorflow / user_ops中。

// tensorflow / user_ops目录中的BUILD文件应具有

tf_custom_op_library(
    name = "roi_pooling_op.so",
    srcs = ["roipooling_op.cc"],
)

py_library(
    name = "roi_pooling_op_py",
    srcs = ["roi_pooling.py"],
    data = [":roi_pooling_op.so"],
    srcs_version = "PY2AND3",
)

* Tensorflow文档中没有提到data = [":roi_pooling_op.so"] ,但是你不必深入了解本地的bazel-bin目录,而是可以使用tf.resource_loader.get_path_to_datafile导入*。所以

_roi_pooling_module = tf.load_op_library(tf.resource_loader.get_path_to_datafile("roi_pooling_op.so"))
roi_pool = _roi_pooling_module.roi_pool
roi_pool_grad = _roi_pooling_module.roi_pool_grad

@ops.RegisterGradient("RoiPool")
def _roi_pool_grad(op, grad, _):
    grad_out = roi_pool_grad(...)
    return grad_out, None

更新冻结构建,在BUILD文件// tensorflow / python / tools目录中,添加"//tensorflow/user_ops:roi_pooling_op_py",作为对freeze_graph py_binary的依赖。

最后重新构建并安装所有内容(custom-op,freeze_graph和pip包/ wheel)

bazel build --config opt //tensorflow/user_ops:roi_pooling_op.so
bazel build --config opt //tensorflow/user_ops:roi_pooling_op_py
bazel build --config opt //tensorflow/python/tools:freeze_graph
bazel build --config opt //tensorflow/tools/pip_package:build_pip_package
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg

pip install --ignore-installed --upgrade /tmp/tensorflow_pkg/tensorflow-1.2.1-py2-none-any.whl

现在你可以在你的python代码中使用它了

from tensorflow.user_ops import roi_pooling

现在您应该可以冻结图形而不会出现任何问题。

我跟着贾里德的回答,我认为它让我大部分时间,但我需要最后一件来自https://stackoverflow.com/a/37556646/7004026 我插入tf.load_op_library('/path/to/custom_op.so')右之前调用import_graph_def直接freeze_graph.py 然后我就能冻结我的图表了。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM