
Using a new op while importing a graph in TensorFlow

I'm new to TensorFlow. I'm trying to import a trained TensorFlow network from checkpoint files. The network uses a custom op, which works fine when I use it from Python. However, I need to freeze the graph because I have to use the C++ API. I'm invoking freeze_graph from the TensorFlow base directory with the following command:

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=<local path>/data/graph_vgg.pb --input_checkpoint=<local path>/data/VGGnet_fast_rcnn_iter_70000.ckpt --output_node_names="cls_prob,bbox_pred" --output_graph=<local path>/graph_frozen.pb

However, I get the following error when I try to freeze the graph:

Traceback (most recent call last):
  File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 202, in <module>
    app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 134, in main
    FLAGS.variable_names_blacklist)
  File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 99, in freeze_graph
    _ = importer.import_graph_def(input_graph_def, name="")
  File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/framework/importer.py", line 260, in import_graph_def
    raise ValueError('No op named %s in defined operations.' % node.op)
ValueError: No op named RoiPool in defined operations.

The input graph has a node with an op of type RoiPool, which TensorFlow does not recognize. I investigated the code that throws this error, and it looks like the op is not registered with TensorFlow. I have the built .so file. Am I supposed to copy it somewhere? I couldn't find anything about that online. Any help or pointers would be great; I've spent a lot of time on this problem. The code works fine in Python, and the layer that uses the op is in the project directory. Please help me understand what I need to do to make it work.
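The check that fails is in importer.py. As far as I can tell from the TF 1.x source, import_graph_def looks up each node's op type in the process-wide op registry and raises exactly this ValueError on the first miss, so what presumably matters is whether the .so has been loaded in the process doing the import, not where the file sits on disk. To list every op type a graph needs that the current process lacks, a small helper like the one below can be used. find_unregistered_ops and Node are hypothetical names for illustration, not TensorFlow APIs. With TF 1.x you would presumably call it as find_unregistered_ops(graph_def.node, op_def_registry.get_registered_ops()), assuming tensorflow.python.framework.op_def_registry.get_registered_ops exists in your version.

```python
from collections import namedtuple


def find_unregistered_ops(nodes, registered_ops):
    """Return the op types used by `nodes` that are absent from `registered_ops`.

    nodes: any iterable of objects with an `.op` attribute, such as the
        `node` field of a GraphDef.
    registered_ops: any container supporting `in`, such as a dict keyed by
        op name or a set of names.
    Each missing op type is reported once, in order of first appearance.
    """
    missing = []
    for node in nodes:
        if node.op not in registered_ops and node.op not in missing:
            missing.append(node.op)
    return missing


# Hypothetical stand-in for GraphDef nodes so the sketch runs without TensorFlow.
Node = namedtuple("Node", ["name", "op"])

example_nodes = [Node("conv5_3", "Conv2D"), Node("pool5", "RoiPool"), Node("pool5_b", "RoiPool")]
print(find_unregistered_ops(example_nodes, {"Conv2D", "Relu"}))  # ['RoiPool']
```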

Edit: This is the code of the custom op used in the network.

I'm not familiar with that specific RoiPooling implementation, but the way I typically set up a custom op that needs to be frozen is to put both roi_pooling_op.cc and the associated Python file (which defines the gradient and imports the *.so) in //tensorflow/user_ops.

The BUILD file in the //tensorflow/user_ops directory should contain:

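load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")

tf_custom_op_library(
    name = "roi_pooling_op.so",
    srcs = ["roi_pooling_op.cc"],
)

py_library(
    name = "roi_pooling_op_py",
    srcs = ["roi_pooling.py"],
    # Ship the compiled op next to roi_pooling.py.
    data = [":roi_pooling_op.so"],
    srcs_version = "PY2AND3",
)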

* The data = [":roi_pooling_op.so"] line isn't mentioned in the TensorFlow docs, but it means you don't have to dig through your local bazel-bin directory; instead, you can use tf.resource_loader.get_path_to_datafile to load the *.so:

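import tensorflow as tf
from tensorflow.python.framework import ops

# Load the compiled op library that the data attribute places next to this file.
_roi_pooling_module = tf.load_op_library(tf.resource_loader.get_path_to_datafile("roi_pooling_op.so"))
roi_pool = _roi_pooling_module.roi_pool
roi_pool_grad = _roi_pooling_module.roi_pool_grad

# The string must match the op type stored in the graph ("RoiPool").
@ops.RegisterGradient("RoiPool")
def _roi_pool_grad(op, grad, _):
    # Arguments elided in the original; they depend on the op's definition.
    grad_out = roi_pool_grad(...)
    return grad_out, None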
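The loading code relies on tf.resource_loader.get_path_to_datafile resolving "roi_pooling_op.so" relative to the Python file that calls it; Bazel's data attribute is what puts the .so in that location in the runfiles tree. A rough pure-Python approximation of that lookup is sketched below as a hypothetical path_to_datafile helper. It is not the TensorFlow function itself, only an illustration of the idea under that assumption.

```python
import inspect
import os
import sys


def path_to_datafile(path, caller_file=None):
    """Resolve `path` relative to the directory of the calling module.

    Approximates the idea behind tf.resource_loader.get_path_to_datafile:
    a file listed in a py_library's `data` attribute ends up alongside the
    library's .py file, so a path relative to that .py file finds it.
    """
    if caller_file is None:
        # File of the function that called us (one frame up the stack).
        caller_file = inspect.getfile(sys._getframe(1))
    return os.path.join(os.path.dirname(caller_file), path)
```

Passing caller_file explicitly makes the behavior easy to check: path_to_datafile("roi_pooling_op.so", "/repo/tensorflow/user_ops/roi_pooling.py") gives "/repo/tensorflow/user_ops/roi_pooling_op.so" on POSIX systems.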

Next, update the freeze_graph build: in the BUILD file in the //tensorflow/python/tools directory, add "//tensorflow/user_ops:roi_pooling_op_py" as a dependency of the freeze_graph py_binary.

Finally, rebuild and install everything (the custom op, freeze_graph, and the pip package/wheel):

bazel build --config opt //tensorflow/user_ops:roi_pooling_op.so
bazel build --config opt //tensorflow/user_ops:roi_pooling_op_py
bazel build --config opt //tensorflow/python/tools:freeze_graph
bazel build --config opt //tensorflow/tools/pip_package:build_pip_package
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg

pip install --ignore-installed --upgrade /tmp/tensorflow_pkg/tensorflow-1.2.1-py2-none-any.whl
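If you iterate on the op, the rebuild sequence can be scripted. This is a minimal sketch, assuming it is run from the TensorFlow source root (the bazel-bin path is relative) with Bazel on PATH; build_commands, find_wheel and rebuild_and_install are hypothetical helper names. Instead of hard-coding the wheel filename, which changes with the TensorFlow version and platform, it installs the newest tensorflow-*.whl in the output directory.

```python
import glob
import os
import subprocess
import sys

# Bazel targets from the build steps above.
TARGETS = [
    "//tensorflow/user_ops:roi_pooling_op.so",
    "//tensorflow/user_ops:roi_pooling_op_py",
    "//tensorflow/python/tools:freeze_graph",
    "//tensorflow/tools/pip_package:build_pip_package",
]


def build_commands(wheel_dir="/tmp/tensorflow_pkg"):
    """Return the build steps as argv lists, in the order they must run."""
    cmds = [["bazel", "build", "--config", "opt", target] for target in TARGETS]
    # Package the freshly built tree into a wheel under wheel_dir.
    cmds.append(["bazel-bin/tensorflow/tools/pip_package/build_pip_package", wheel_dir])
    return cmds


def find_wheel(wheel_dir):
    """Return the most recently modified tensorflow-*.whl in wheel_dir."""
    wheels = glob.glob(os.path.join(wheel_dir, "tensorflow-*.whl"))
    if not wheels:
        raise FileNotFoundError("no tensorflow wheel found in %s" % wheel_dir)
    return max(wheels, key=os.path.getmtime)


def rebuild_and_install(wheel_dir="/tmp/tensorflow_pkg"):
    """Run every build step, then install the resulting wheel."""
    for cmd in build_commands(wheel_dir):
        subprocess.check_call(cmd)  # raises CalledProcessError on the first failure
    wheel = find_wheel(wheel_dir)
    # Install into the interpreter that is running this script.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--ignore-installed", "--upgrade", wheel]
    )
```

Calling rebuild_and_install() performs the same steps as the commands above. Using sys.executable -m pip rather than a bare pip makes sure the wheel lands in the same Python environment you later run your code in.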

You can now use it in your Python code with:

from tensorflow.user_ops import roi_pooling

Now you should be able to freeze the graph without any issues.

I followed Jared's answer, and it got me most of the way, but I needed one last piece from https://stackoverflow.com/a/37556646/7004026: I inserted tf.load_op_library('/path/to/custom_op.so') directly in freeze_graph.py, right before the call to import_graph_def. Then I was able to freeze my graph.
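The point of this last fix is ordering: the .so must be loaded in the freeze_graph process before import_graph_def runs, because loading it is what registers RoiPool. A minimal sketch of that step is below, written with the loader and importer passed in so it can be exercised without TensorFlow. import_with_custom_ops is a hypothetical helper; in a TF 1.x freeze_graph.py you would presumably pass tf.load_op_library and importer.import_graph_def.

```python
def import_with_custom_ops(graph_def, op_library_paths, load_op_library, import_graph_def):
    """Register custom ops by loading their libraries, then import graph_def.

    load_op_library: e.g. tf.load_op_library; called once per .so path.
    import_graph_def: e.g. importer.import_graph_def; called with name=""
        as in freeze_graph.py.
    """
    # Load first: the import looks up every node's op type in the registry,
    # and loading the .so is what adds the custom op (RoiPool) to it.
    for path in op_library_paths:
        load_op_library(path)
    return import_graph_def(graph_def, name="")
```

In freeze_graph.py this would replace the existing import line, e.g. import_with_custom_ops(input_graph_def, ["/path/to/custom_op.so"], tf.load_op_library, importer.import_graph_def), with the same placeholder path as above.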
