简体   繁体   English

Tensorflow模型导入Java

[英]Tensorflow model import to Java

I have been trying to import and make use of my trained model (Tensorflow, Python) in Java. 我一直在尝试导入并使用Java中训练有素的模型(Tensorflow,Python)。

I was able to save the model in Python, but encountered problems when I try to make predictions using the same model in Java. 我能够在Python中保存模型,但是当我尝试使用Java中的相同模型进行预测时遇到了问题。

Here , you can see the python code for initializing, training, saving the model. 在这里 ,您可以看到用于初始化,训练,保存模型的python代码。

Here , you can see the Java code for importing and making predictions for input values. 在这里 ,您可以看到用于导入和预测输入值的Java代码。

The error message I get is: 我得到的错误信息是:

Exception in thread "main" java.lang.IllegalStateException: Attempting to use uninitialized value Variable_7
     [[Node: Variable_7/read = Identity[T=DT_FLOAT, _class=["loc:@Variable_7"], _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_7)]]
    at org.tensorflow.Session.run(Native Method)
    at org.tensorflow.Session.access$100(Session.java:48)
    at org.tensorflow.Session$Runner.runHelper(Session.java:285)
    at org.tensorflow.Session$Runner.run(Session.java:235)
    at org.tensorflow.examples.Identity_import.main(Identity_import.java:35)

I believe, the problem is somewhere in the python code, but I was not able to find it. 我相信,问题出在python代码中,但我无法找到它。

The Java importGraphDef() function is only importing the computational graph (written by tf.train.write_graph in your Python code), it isn't loading the values of trained variables (stored in the checkpoint), which is why you get an error complaining about uninitialized variables. Java importGraphDef()函数只导入计算图(由Python代码中的tf.train.write_graph编写),它不加载训练变量的值(存储在检查点中),这就是为什么你会得到一个错误抱怨未初始化的变量。

The TensorFlow SavedModel format on the other hand includes all information about a model (graph, checkpoint state, other metadata) and to use in Java you'd want to use SavedModelBundle.load to create session initialized with the trained variable values. 另一方面, TensorFlow SavedModel格式包括有关模型的所有信息(图形,检查点状态,其他元数据),并且在Java中使用,您希望使用SavedModelBundle.load来创建使用训练变量值初始化的会话。

To export a model in this format from Python, you might want to take a look at a related question Deploy retrained inception SavedModel to google cloud ml engine 要从Python导出这种格式的模型,您可能需要查看相关问题将重新训练开始的SavedModel部署到Google Cloud ml引擎

In your case, this should amount to something like the following in Python: 在您的情况下,这应该类似于Python中的以下内容:

def save_model(session, input_tensor, output_tensor):
  signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs = {'input': tf.saved_model.utils.build_tensor_info(input_tensor)},
    outputs = {'output': tf.saved_model.utils.build_tensor_info(output_tensor)},
  )
  b = saved_model_builder.SavedModelBuilder('/tmp/model')
  b.add_meta_graph_and_variables(session,
                                 [tf.saved_model.tag_constants.SERVING],
                                 signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature})
  b.save() 

And invoke that via save_model(session, x, yhat) 并通过save_model(session, x, yhat)调用它save_model(session, x, yhat)

And then in Java load the model using: 然后在Java中加载模型使用:

try (SavedModelBundle b = SavedModelBundle.load("/tmp/mymodel", "serve")) {
  // b.session().run(...)
}

Hope that helps. 希望有所帮助。

Fwiw, Deeplearning4j lets you import models trained on TensorFlow with Keras 1.0 (Keras 2.0 support is on the way). Fwiw,Deeplearning4j允许您使用Keras 1.0导入在TensorFlow上训练的模型(Keras 2.0支持即将推出)。

https://deeplearning4j.org/model-import-keras https://deeplearning4j.org/model-import-keras

We also built a library called Jumpy, which is a wrapper around Numpy arrays and Pyjnius that uses pointers instead of copying data, which makes it more efficient than Py4j when dealing with tensors. 我们还构建了一个名为Jumpy的库,它是Numpy数组和Pyjnius的包装器,它使用指针而不是复制数据,这使得它在处理张量时比Py4j更有效。

https://deeplearning4j.org/jumpy https://deeplearning4j.org/jumpy

Your python-model will certainly fail at this: 你的python模型肯定会失败:

sess.run(init) #<---this will fail
save_model(sess)
error = tf.reduce_mean(tf.square(prediction - y))

#accuracy = tf.reduce_mean(tf.cast(error, 'float'))
print('Error:', error)

init is not defined in the model - I'm unsure what you want achieve at this place, but that should give you a starting point init没有在模型中定义 - 我不确定你想在这个地方实现什么,但这应该给你一个起点

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM