[英]Using new op while importing graph in tensorflow
我是TensorFlow的新手。 我正在尝试使用检查点文件导入训练有素的TensorFlow网络。 我正在使用的网络有一个自定义操作,当我在Python中使用它时可以正常工作。 但是,我必须冻结图形,因为我必须使用C ++ API。 我正在使用TensorFlow基目录中的以下命令调用freeze_graph
:
bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=<local path>/data/graph_vgg.pb --input_checkpoint=<local path>/data/VGGnet_fast_rcnn_iter_70000.ckpt --output_node_names="cls_prob,bbox_pred" --output_graph=<local path>/graph_frozen.pb
但是,当我试图冻结图形时,我收到以下错误。
Traceback (most recent call last):
File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 202, in <module>
app.run(main=main, argv=[sys.argv[0]] + unparsed)
File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 44, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 134, in main
FLAGS.variable_names_blacklist)
File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 99, in freeze_graph
_ = importer.import_graph_def(input_graph_def, name="")
File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/framework/importer.py", line 260, in import_graph_def
raise ValueError('No op named %s in defined operations.' % node.op)
ValueError: No op named RoiPool in defined operations.
输入图有一个类型为RoiPool
的op的节点,TensorFlow无法识别。 我调查了抛出此错误的代码,它看起来像是未在TensorFlow中注册的操作。 我和我一起构建了.so
文件。 我应该把它复制到某个地方吗? 我在网上找不到那样的东西。 任何帮助或指针都会很棒。 我花了很多时间来解决这个问题。 代码在python中工作正常,使用op的层在项目目录中。 请帮助我理解我需要做些什么来使它工作。
我不熟悉那个特定的RoiPooling实现,但我通常设置一个需要冻结的自定义op的方式是roi_pooling_op.cc和相关的python文件(定义渐变并导入* .so)都位于// tensorflow / user_ops中。
// tensorflow / user_ops目录中的BUILD文件应具有
tf_custom_op_library(
name = "roi_pooling_op.so",
srcs = ["roipooling_op.cc"],
)
py_library(
name = "roi_pooling_op_py",
srcs = ["roi_pooling.py"],
data = [":roi_pooling_op.so"],
srcs_version = "PY2AND3",
)
* Tensorflow文档中没有提到data = [":roi_pooling_op.so"]
,但是你不必深入了解本地的bazel-bin目录,而是可以使用tf.resource_loader.get_path_to_datafile
导入*。所以
_roi_pooling_module = tf.load_op_library(tf.resource_loader.get_path_to_datafile("roi_pooling_op.so"))
roi_pool = _roi_pooling_module.roi_pool
roi_pool_grad = _roi_pooling_module.roi_pool_grad
@ops.RegisterGradient("RoiPool")
def _roi_pool_grad(op, grad, _):
grad_out = roi_pool_grad(...)
return grad_out, None
更新冻结构建,在BUILD文件// tensorflow / python / tools目录中,添加"//tensorflow/user_ops:roi_pooling_op_py",
作为对freeze_graph py_binary的依赖。
最后重新构建并安装所有内容(custom-op,freeze_graph和pip包/ wheel)
bazel build --config opt //tensorflow/user_ops:roi_pooling_op.so
bazel build --config opt //tensorflow/user_ops:roi_pooling_op_py
bazel build --config opt //tensorflow/python/tools:freeze_graph
bazel build --config opt //tensorflow/tools/pip_package:build_pip_package
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
pip install --ignore-installed --upgrade /tmp/tensorflow_pkg/tensorflow-1.2.1-py2-none-any.whl
现在你可以在你的python代码中使用它了
from tensorflow.user_ops import roi_pooling
现在您应该可以冻结图形而不会出现任何问题。
我跟着贾里德的回答,我认为它让我大部分时间,但我需要最后一件来自https://stackoverflow.com/a/37556646/7004026 。 我插入tf.load_op_library('/path/to/custom_op.so')
右之前调用import_graph_def
直接freeze_graph.py
。 然后我就能冻结我的图表了。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.