Combining graphs: is there a TensorFlow import_graph_def equivalent for C++?

I need to extend exported models with a custom input and output layer. I have found that this can easily be done with:

import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)

with tf.Graph().as_default() as g1: # actual model
    in1 = tf.placeholder(tf.float32,name="input")
    ou1 = tf.add(in1,2.0,name="output")
with tf.Graph().as_default() as g2: # model for the new output layer
    in2 = tf.placeholder(tf.float32,name="input")
    ou2 = tf.add(in2,2.0,name="output")

gdef_1 = g1.as_graph_def()
gdef_2 = g2.as_graph_def()

with tf.Graph().as_default() as g_combined: #merge together
    x = tf.placeholder(tf.float32, name="actual_input") # the new input layer

    # Import gdef_1, which performs f(x).
    # "input:0" and "output:0" are the names of tensors in gdef_1.
    y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
                             return_elements=["output:0"])

    # Import gdef_2, which performs g(y)
    z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
                             return_elements=["output:0"])

sess = tf.Session(graph=g_combined)

print("result is:", sess.run(z, {"actual_input:0": 5}))  # result is: 9.0

This works fine.
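The sketch below shows, in plain C++, roughly what tf.import_graph_def(..., input_map=...) does to a list of nodes. It is not TensorFlow code. Node is a hypothetical stand-in for tensorflow::NodeDef, and ImportInto is a hypothetical helper. Each imported node gets a name prefix; by default Python uses import, then import_1. Any input that matches an input_map key is redirected to a tensor that already exists in the target graph.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for tensorflow::NodeDef: only name, op and inputs.
struct Node {
    std::string name;
    std::string op;
    std::vector<std::string> inputs;  // "node", "node:N" or "^node"
};

// Simplified model of tf.import_graph_def(gdef, input_map=..., return_elements=...).
// Each imported node is renamed "prefix/name". A data input whose tensor name
// (a bare "node" counts as "node:0") is a key of input_map is replaced by the
// mapped tensor of the target graph; other inputs just get the prefix.
// Mapped placeholders are still copied here, they simply end up unused.
std::vector<std::string> ImportInto(std::vector<Node>& target,
                                    const std::vector<Node>& imported,
                                    const std::string& prefix,
                                    const std::map<std::string, std::string>& input_map,
                                    const std::vector<std::string>& return_elements) {
    for (const Node& src : imported) {
        Node copy{prefix + "/" + src.name, src.op, {}};
        for (const std::string& in : src.inputs) {
            if (!in.empty() && in[0] == '^') {  // control input: rename only
                copy.inputs.push_back("^" + prefix + "/" + in.substr(1));
                continue;
            }
            const std::string key = in.find(':') == std::string::npos ? in + ":0" : in;
            auto it = input_map.find(key);
            copy.inputs.push_back(it != input_map.end() ? it->second : prefix + "/" + in);
        }
        target.push_back(copy);
    }
    // Returned tensor names refer to the renamed nodes, e.g. "import/output:0".
    std::vector<std::string> returned;
    for (const std::string& element : return_elements) {
        returned.push_back(prefix + "/" + element);
    }
    return returned;
}
```

To reproduce the wiring of the Python example, first import gdef_1 with prefix import and the map {input:0 → actual_input:0}. Then import gdef_2 with prefix import_1 and the map {input:0 → import/output:0}. The real function does more work, such as validating the imported ops. This sketch only shows the renaming and rewiring.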

However, instead of passing a dataset of arbitrary shape, I need to pass a pointer as the network input. The problem is that I can't think of any way to do this in Python (defining and passing a pointer), and when building a network with the C++ API I can't find an equivalent of the tf.import_graph_def function.

Does this have a different name in C++, or is there another way to merge two graphs/models in C++?

Thanks for any advice.

It is not as easy in C++ as it is in Python.

You can load a GraphDef with something like this:

#include <string>
#include <tensorflow/core/framework/graph.pb.h>
#include <tensorflow/core/platform/env.h>

tensorflow::GraphDef graph;
std::string graphFileName = "...";
auto status = tensorflow::ReadBinaryProto(
    tensorflow::Env::Default(), graphFileName, &graph);
if (!status.ok()) { /* Error... */ }

Then you can use it to create a session:

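#include <tensorflow/core/public/session.h>

// NewSession allocates the session; the caller owns it and must delete it
// (or wrap it in a std::unique_ptr) when done. `status` is the variable
// declared in the previous snippet.
tensorflow::Session *session = nullptr;
status = tensorflow::NewSession(tensorflow::SessionOptions(), &session);
if (!status.ok()) { /* Error... */ }
status = session->Create(graph);
if (!status.ok()) { /* Error... */ }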

Or you can extend the graph of an existing session:

status = session->Extend(graph);
if (!status.ok()) { /* Error... */ }
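The nodes of the added GraphDef must not reuse names that already exist in the session's graph. A quick overlap check before calling Extend makes such a conflict easy to spot. The helper below is hypothetical and works on plain name lists. It is not part of the TensorFlow API.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical helper: returns the node names that appear in both lists,
// sorted and without repeats. With real GraphDefs, fill the lists from
// node.name() for every node in graph.node().
std::vector<std::string> ConflictingNames(const std::vector<std::string>& existing,
                                          const std::vector<std::string>& incoming) {
    const std::set<std::string> known(existing.begin(), existing.end());
    std::set<std::string> conflicts;
    for (const std::string& name : incoming) {
        if (known.count(name) != 0) {
            conflicts.insert(name);
        }
    }
    return std::vector<std::string>(conflicts.begin(), conflicts.end());
}
```

An empty result means the second GraphDef can be added without clashes. Otherwise the listed nodes need renaming first.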

This way you can put several GraphDefs into the same graph. However, there are no additional facilities for extracting particular nodes or for avoiding name collisions: you have to find the nodes yourself and ensure that the GraphDefs do not have conflicting op names. For example, I use this function to find all nodes whose names match a given regular expression, sorted by name:

#include <algorithm>  // std::sort
#include <regex>
#include <vector>
#include <tensorflow/core/framework/graph.pb.h>
#include <tensorflow/core/framework/node_def.pb.h>

std::vector<const tensorflow::NodeDef *> GetNodes(const tensorflow::GraphDef &graph, const std::regex &regex)
{
    std::vector<const tensorflow::NodeDef *> nodes;
    for (const auto &node : graph.node())
    {
        if (std::regex_match(node.name(), regex))
        {
            nodes.push_back(&node);
        }
    }
    std::sort(nodes.begin(), nodes.end(),
              [](const tensorflow::NodeDef *lhs, const tensorflow::NodeDef *rhs)
              {
                  return lhs->name() < rhs->name();
              });
    return nodes;
}
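GetNodes has two details worth knowing. First, std::regex_match succeeds only when the regular expression matches the entire node name. Second, sorting by name is lexicographic. The sketch below applies the same filter-and-sort to a plain list of names, so it builds without TensorFlow. MatchNames is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <regex>
#include <string>
#include <vector>

// Hypothetical plain-string version of GetNodes: keeps the names that the
// regex matches completely (std::regex_match, not std::regex_search) and
// returns them in lexicographic order.
std::vector<std::string> MatchNames(const std::vector<std::string>& names,
                                    const std::regex& regex) {
    std::vector<std::string> matched;
    for (const std::string& name : names) {
        if (std::regex_match(name, regex)) {
            matched.push_back(name);
        }
    }
    std::sort(matched.begin(), matched.end());
    return matched;
}
```

Take the names add, add_1, add_2, add_10 and import/add. The pattern add selects only add. The pattern add(_\d+)? selects the first four, in the order add, add_1, add_10, add_2. So add_10 sorts before add_2, and import/add is not selected because the whole name must match.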

This can be achieved in C++ by directly manipulating the NodeDefs in the GraphDefs of the two graphs to be combined. The basic algorithm is to define the two GraphDefs, using Placeholders for the inputs of the second GraphDef, and then redirect those inputs to the outputs of the first GraphDef. This is analogous to connecting two electric circuits in series by wiring the inputs of the second circuit to the outputs of the first.

First, the sample GraphDefs are defined, along with a utility for inspecting their internals. Note that all nodes from both GraphDefs must have unique names.

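#include <iostream>
#include <string>
#include <vector>
#include <tensorflow/cc/framework/scope.h>
#include <tensorflow/cc/ops/standard_ops.h>
#include <tensorflow/core/framework/graph.pb.h>
#include <tensorflow/core/public/session.h>

using namespace std;
using namespace tensorflow;
using namespace tensorflow::ops;

// Panel is the answer author's own class; its declaration is not shown.
Status Panel::SampleFirst(GraphDef *graph_def)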
{
    Scope root = Scope::NewRootScope();
    Placeholder p1(root.WithOpName("p1"), DT_INT32);
    Placeholder p2(root.WithOpName("p2"), DT_INT32);
    Add add(root.WithOpName("add"), p1, p2);
    return root.ToGraphDef(graph_def);
}

Status Panel::SampleSecond(GraphDef *graph_def)
{
    Scope root = Scope::NewRootScope();
    Placeholder q1(root.WithOpName("q1"), DT_INT32);
    Placeholder q2(root.WithOpName("q2"), DT_INT32);
    Add sum(root.WithOpName("sum"), q1, q2);
    Multiply multiply(root.WithOpName("multiply"), sum, 4);
    return root.ToGraphDef(graph_def);
}

void Panel::ShowGraphDef(GraphDef &graph_def)
{
    for (int i = 0; i < graph_def.node_size(); i++) {
        const NodeDef& node_def = graph_def.node(i);  // read-only, no copy needed
        cout << "NodeDef name is " << node_def.name() << endl;
        cout << "NodeDef op is " << node_def.op() << endl;
        for (const string& input : node_def.input()) {
            cout << "\t input: " << input << endl;
        }
    }
}
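ShowGraphDef prints each input exactly as the NodeDef stores it. In a GraphDef these strings take one of three forms: node (output 0), node:N (output N), or ^node (a control dependency). Code that rewires inputs by name, like Append, has to handle all three. The hypothetical helper below splits such a string into its parts. It is plain C++ and does not use TensorFlow.

```cpp
#include <cassert>
#include <string>

// Parsed form of one entry of NodeDef::input() (hypothetical type).
struct InputRef {
    std::string node;      // name of the producing node
    int index = 0;         // output slot; a bare "node" means slot 0
    bool control = false;  // true for "^node" (control dependency)
};

InputRef ParseInput(const std::string& input) {
    InputRef ref;
    std::string s = input;
    if (!s.empty() && s[0] == '^') {  // control inputs carry no output index
        ref.control = true;
        s = s.substr(1);
    }
    const auto colon = s.rfind(':');
    if (colon != std::string::npos) {  // "node:N" -> slot N
        ref.index = std::stoi(s.substr(colon + 1));
        s = s.substr(0, colon);
    }
    ref.node = s;
    return ref;
}
```

ParseInput("split:1") yields node split and output 1, and ParseInput("^init") yields node init flagged as a control input.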

Next, the two GraphDefs are created and the inputs of the second GraphDef are connected to nodes of the first. This is done by iterating over the nodes of the second GraphDef, finding its first computational node (sum), whose inputs are the Placeholders, and redirecting those inputs to nodes of the first GraphDef (p1 and add). That node is then added to the first GraphDef, as are all subsequent nodes. The result is the first GraphDef with the second GraphDef appended to it.

Status Panel::Append(vector<Tensor> *outputs)
{
    GraphDef graph_def_first;
    GraphDef graph_def_second;
    TF_RETURN_IF_ERROR(SampleFirst(&graph_def_first));
    TF_RETURN_IF_ERROR(SampleSecond(&graph_def_second));

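    // Append the second GraphDef to the first. Its Placeholders (q1, q2) are
    // skipped because nothing refers to them after "sum" is rewired.
    for (int i = 0; i < graph_def_second.node_size(); i++) {
        NodeDef node_def = graph_def_second.node(i);  // copy, so it can be edited
        if (node_def.op() == "Placeholder") {
            continue;
        }
        if (node_def.name() == "sum") {
            node_def.set_input(0, "p1");   // was q1
            node_def.set_input(1, "add");  // was q2
        }
        *graph_def_first.add_node() = node_def;
    }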

    ShowGraphDef(graph_def_first);

    unique_ptr<Session> session(NewSession(SessionOptions()));
    TF_RETURN_IF_ERROR(session->Create(graph_def_first));

    Tensor t1(2);
    Tensor t2(3);
    vector<pair<string, Tensor>> inputs = {{"p1", t1}, {"p2", t2}};

    TF_RETURN_IF_ERROR(session->Run(inputs, {"multiply"}, {}, outputs));

    return Status::OK();
}
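Append hard-codes the node name sum and its two new inputs. A more general variant could take a map from each Placeholder of the second graph to the input string, in the first graph, that should replace it. The sketch below uses a hypothetical Node struct instead of tensorflow::NodeDef so that it compiles without TensorFlow. With real protos, the same loop would use op(), set_input() and add_node(), as Append does. AppendRewired is a hypothetical name.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for tensorflow::NodeDef.
struct Node {
    std::string name;
    std::string op;
    std::vector<std::string> inputs;
};

// Appends `second` to `first`. Every data input that refers to a placeholder
// listed in `rewire` (placeholder name -> replacement input in `first`) is
// redirected, and those placeholders are not copied. Assumptions: no other
// name collisions, rewired placeholders are consumed as "q1" or "q1:0", and
// control inputs such as "^q1" are not rewritten.
void AppendRewired(std::vector<Node>& first, const std::vector<Node>& second,
                   const std::map<std::string, std::string>& rewire) {
    for (Node node : second) {  // copy: inputs are edited below
        if (node.op == "Placeholder" && rewire.count(node.name) != 0) {
            continue;
        }
        for (std::string& in : node.inputs) {
            std::string base = in;
            if (base.size() > 2 && base.compare(base.size() - 2, 2, ":0") == 0) {
                base.resize(base.size() - 2);  // "q1:0" -> "q1"
            }
            auto it = rewire.find(base);
            if (it != rewire.end()) {
                in = it->second;
            }
        }
        first.push_back(node);
    }
}
```

With the map {q1 → p1, q2 → add}, this gives the same wiring as Append without naming sum explicitly. Keying on the Placeholder rather than on its consumer means that every node reading q1 or q2 is redirected, not just the first one.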

This particular graph takes two inputs, 2 and 3, and adds them together. The sum (5) is then added to the first input (2) again, and the result is multiplied by 4 to give 28: ((2 + 3) + 2) * 4 = 28.
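The arithmetic can be checked with a tiny evaluator over the merged node list. Node and Eval are hypothetical plain-C++ stand-ins, not TensorFlow's executor. The constant 4 sits in a node with a made-up name (four); in the real graph, the C++ API chooses that constant's name.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical toy node: op is "Placeholder", "Const", "Add" or "Mul".
struct Node {
    std::string name;
    std::string op;
    std::vector<std::string> inputs;
    int value;  // used by "Const" only; zero when omitted in brace initialization
};

// Recursively evaluates node `name`, reading Placeholder values from `feeds`.
int Eval(const std::vector<Node>& graph, const std::map<std::string, int>& feeds,
         const std::string& name) {
    for (const Node& n : graph) {
        if (n.name != name) continue;
        if (n.op == "Placeholder") return feeds.at(n.name);
        if (n.op == "Const") return n.value;
        const int a = Eval(graph, feeds, n.inputs.at(0));
        const int b = Eval(graph, feeds, n.inputs.at(1));
        if (n.op == "Add") return a + b;
        if (n.op == "Mul") return a * b;
        throw std::runtime_error("unsupported op " + n.op);
    }
    throw std::runtime_error("no node named " + name);
}
```

With p1 = 2 and p2 = 3, evaluating multiply on the merged list (p1, p2, add, sum rewired to p1 and add, four, multiply) gives 28. Evaluating add alone gives 5.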
