简体   繁体   English

Python中的Tensorflow Java Api`toGraphDef`等价是什么?

[英]What is the Tensorflow Java Api `toGraphDef` equivalent in Python?

I am using the Tensorflow Java Api to load an already created Tensorflow model into the JVM. 我正在使用Tensorflow Java Api将已经创建的Tensorflow模型加载到JVM中。 I am using this as an example: tensorflow/examples/LabelImage.java 我以这个为例: tensorflow / examples / LabelImage.java

Here is my simple scala code: 这是我简单的scala代码:

import java.nio.file.{Files, Path, Paths}
import org.tensorflow.{Graph, Session, Tensor}

def readAllBytesOrExit(path: Path): Array[Byte] = Files.readAllBytes(path)
val graphDef = readAllBytesOrExit(Paths.get("PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb"))
val g = new Graph()
g.importGraphDef(graphDef)
val session = new Session(g)
val result: Tensor = session.runner().feed("input", image).fetch("output").run().get(0))

How do I save my model to get both the Session and the Graph stored in the same file. 如何保存我的模型,以便将会话和图都存储在同一文件中。 as described in the "PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb" above. 如上文“ PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb”中所述。

Described here it mentions: 在这里描述它提到:

The serialized representation of the graph, often referred to as a GraphDef, can be generated by toGraphDef() and equivalents in other language APIs. 图的序列化表示形式(通常称为GraphDef)可以由toGraphDef()和其他语言API中的等效项生成。

What are the equivalents in other language APIs? 其他语言API的等效项是什么? I dont find it obvious 我觉得不明显

Note: I already looked at the mnist_saved_model.py under tensorflow_serving but saving it through that procedure gives me a .pb file and a variables folder. 注意:我已经看过tensorflow_serving下的mnist_saved_model.py,但是通过该过程保存后,得到的是.pb文件和variables文件夹。 When trying to load that .pb file I get: java.lang.IllegalArgumentException: Invalid GraphDef 尝试加载该.pb文件时,我得到: java.lang.IllegalArgumentException: Invalid GraphDef

Currently with the Java API of tensorflow, I only found how to save a graph as a graphDef (ie without its variables and meta-data). 当前使用tensorflow的Java API,我仅发现了如何将图形另存为graphDef(即没有其变量和元数据)。 This can be done by just writing the Array[Byte] to a file: 只需将Array [Byte]写入文件即可完成:

Files.write(Paths.get(modelDir, modelName), myGraph.toGraphDef)

Here myGraph is a java object from the Graph class . 这里的myGraphGraph类的java对象。

I would suggest to save your model from the Python API, using the SavedModel api defined here. 我建议使用此处定义的SavedModel api从Python API保存模型。 It will save your model in a folder with both the serialized graph in a .pb file and the variables in a folder. 它将模型保存在一个文件夹中,而序列化图形保存在.pb文件中,变量保存在文件夹中。 Note the tag_constants you use as you'll need it in your scala/java code to load the model with the variables. 请注意,在scala / java代码中需要使用tag_constants来加载带有变量的模型。 Then the graph and session with variables are easily loaded with the SavedModelBundle java class from the java api. 然后,可以使用来自Java api的SavedModelBundle Java类轻松加载带有变量的图形和会话。 It returns you a wrapper with both the graph and the session containing the variables values: 它返回一个包装,其中包含图形和包含变量值的会话:

val model = SavedModelBundle.load(modelDir, modelTag)

If you already tried this, maybe you can share your code to see why it returned an invalid GraphDef. 如果您已经尝试过此操作,也许可以共享您的代码以查看为什么它返回无效的GraphDef。

Another option is to freeze your graph, ie you turned your variable nodes into constant Nodes so everything is self-contained in the .pb file. 另一个选择是冻结图形,即将变量节点转换为常量节点,以便所有内容都独立包含在.pb文件中。 Mores infos here for the freezing part 摩尔斯的相关信息在这里对冷冻部分

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM