简体   繁体   English

相当于python的C++:Tensorflow中的tf.Graph.get_tensor_by_name()?

[英]C++ equivalent of python: tf.Graph.get_tensor_by_name() in Tensorflow?

What is the C++ equivalent of python: tf.Graph.get_tensor_by_name(name) in Tensorflow? python: tf.Graph.get_tensor_by_name(name) 在 Tensorflow 中的 C++ 等价物是什么? Thanks!谢谢!

Here is the code I am trying to run, but I get an empty output :这是我尝试运行的代码,但output为空:

Status status = NewSession(SessionOptions(), &session); // create new session
ReadBinaryProto(tensorflow::Env::Default(), model, &graph_def); // read Graph
session->Create(graph_def); // add Graph to Tensorflow session 
std::vector<tensorflow::Tensor> output; // create Tensor to store output
std::vector<string> vNames; // vector of names for required graph nodes
vNames.push_back("some_name"); // I checked names and they are presented in loaded Graph

session->Run({}, vNames, {}, &output); // ??? As a result I have empty output

From your comment , it sounds like you are using the C++ tensorflow::Session API, which represents graphs as GraphDef protocol buffers.根据您的评论,听起来您正在使用 C++ tensorflow::Session API,它将图形表示为GraphDef协议缓冲区。 There is no equivalent to tf.Graph.get_tensor_by_name() in this API.此 API 中没有tf.Graph.get_tensor_by_name()等效的方法

Instead of passing typed tf.Tensor objects to Session::Run() , you pass the string names of tensors, which have the form <NODE NAME>:<N> , where <NODE NAME> matches one of the NodeDef.name values in the GraphDef , and <N> is an integer corresponding to the index of the the output from that node that you want to fetch.不是将类型化的tf.Tensor对象传递给Session::Run() ,而是传递张量的string名称,其格式为<NODE NAME>:<N> ,其中<NODE NAME>匹配NodeDef.name值之一在GraphDef<N>是一个整数,对应于您要获取的该节点的输出索引。

The code in your question looks roughly correct, but there are two things I'd advise:您问题中的代码看起来大致正确,但我建议两件事:

  1. The session->Run() call returns a tensorflow::Status value. session->Run()调用返回一个tensorflow::Status值。 If output is empty after the the call returns, it is almost certain that the call returned an error status with a message that explains the problem.如果调用返回后output为空,则几乎可以肯定调用返回了错误状态,并带有解释问题的消息。

  2. You're passing "some_name" as the name of a tensor to fetch, but it is the name of a node, not a tensor.您将"some_name"作为要获取的张量的名称传递,但它是节点的名称,而不是张量。 It is possible that this API requires you to specify the output index explicitly: try replacing it with "some_name:0" .此 API 可能要求您明确指定输出索引:尝试将其替换为"some_name:0"

there is a way to get neural node from graph_def directly.有一种方法可以直接从 graph_def 获取神经节点。 if u only want the shape\\type of node: "some_name":如果你只想要节点的形状\\类型:“some_name”:

void readPB(GraphDef & graph_def)
{

    int i;
    for (i = 0; i < graph_def.node_size(); i++)
    {
        if (graph_def.node(i).name() == "inputx")
        {
            graph_def.node(i).PrintDebugString();
        }
    }
}

results:结果:

name: "inputx"
op: "Placeholder"
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "shape"
  value {
    shape {
      dim {
        size: -1
      }
      dim {
        size: 5120
      }
    }
  }
}

try member functins of the node and get the informations.尝试节点的成员函数并获取信息。

In case anybody is interested, here is how you extract the shape of an arbitrary sensor from graph_def using the tensorflow C++ API如果有人感兴趣,这里是如何使用 tensorflow C++ API 从 graph_def 中提取任意传感器的形状

vector<int64_t> get_shape_of_tensor(tensorflow::GraphDef graph_def, std::string name_tensor)
{
    vector<int64_t> tensor_shape;
    for (int i=0; i < graph_def.node_size(); ++i) {
        if (graph_def.node(i).name() == name_tensor) {
            auto node = graph_def.node(i);
            auto attr_map = node.attr();
            for (auto it=attr_map.begin(); it != attr_map.end(); it++) {
                auto key = it->first;
                auto value = it->second;
                if (value.has_shape()) {
                    auto shape = value.shape();
                    for (int i=0; i<shape.dim_size(); ++i) {
                        auto dim = shape.dim(i);
                        auto dim_size = dim.size();
                        tensor_shape.push_back(dim_size);
                    }
                }
            }
        }
    }
    return tensor_shape
}

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM