簡體   English   中英

相當於python的C++:Tensorflow中的tf.Graph.get_tensor_by_name()?

[英]C++ equivalent of python: tf.Graph.get_tensor_by_name() in Tensorflow?

python: tf.Graph.get_tensor_by_name(name) 在 Tensorflow 中的 C++ 等價物是什么? 謝謝!

這是我嘗試運行的代碼,但output為空:

Status status = NewSession(SessionOptions(), &session); // create new session
ReadBinaryProto(tensorflow::Env::Default(), model, &graph_def); // read Graph
session->Create(graph_def); // add Graph to Tensorflow session 
std::vector<tensorflow::Tensor> output; // create Tensor to store output
std::vector<string> vNames; // vector of names for required graph nodes
vNames.push_back("some_name"); // I checked names and they are presented in loaded Graph

session->Run({}, vNames, {}, &output); // ??? As a result I have empty output

根據您的評論,聽起來您正在使用 C++ tensorflow::Session API,它將圖形表示為GraphDef協議緩沖區。 此 API 中沒有tf.Graph.get_tensor_by_name()等效的方法

不是將類型化的tf.Tensor對象傳遞給Session::Run() ,而是傳遞張量的string名稱,其格式為<NODE NAME>:<N> ,其中<NODE NAME>匹配NodeDef.name值之一在GraphDef<N>是一個整數,對應於您要獲取的該節點的輸出索引。

您問題中的代碼看起來大致正確,但我建議兩件事:

  1. session->Run()調用返回一個tensorflow::Status值。 如果調用返回后output為空,則幾乎可以肯定調用返回了錯誤狀態,並帶有解釋問題的消息。

  2. 您將"some_name"作為要獲取的張量的名稱傳遞,但它是節點的名稱,而不是張量。 此 API 可能要求您明確指定輸出索引:嘗試將其替換為"some_name:0"

有一種方法可以直接從 graph_def 獲取神經節點。 如果你只想要節點的形狀\\類型:“some_name”:

void readPB(GraphDef & graph_def)
{

    int i;
    for (i = 0; i < graph_def.node_size(); i++)
    {
        if (graph_def.node(i).name() == "inputx")
        {
            graph_def.node(i).PrintDebugString();
        }
    }
}

結果:

name: "inputx"
op: "Placeholder"
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "shape"
  value {
    shape {
      dim {
        size: -1
      }
      dim {
        size: 5120
      }
    }
  }
}

嘗試節點的成員函數並獲取信息。

如果有人感興趣,這里是如何使用 tensorflow C++ API 從 graph_def 中提取任意傳感器的形狀

vector<int64_t> get_shape_of_tensor(tensorflow::GraphDef graph_def, std::string name_tensor)
{
    vector<int64_t> tensor_shape;
    for (int i=0; i < graph_def.node_size(); ++i) {
        if (graph_def.node(i).name() == name_tensor) {
            auto node = graph_def.node(i);
            auto attr_map = node.attr();
            for (auto it=attr_map.begin(); it != attr_map.end(); it++) {
                auto key = it->first;
                auto value = it->second;
                if (value.has_shape()) {
                    auto shape = value.shape();
                    for (int i=0; i<shape.dim_size(); ++i) {
                        auto dim = shape.dim(i);
                        auto dim_size = dim.size();
                        tensor_shape.push_back(dim_size);
                    }
                }
            }
        }
    }
    return tensor_shape
}

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM