[英]Tensorflow model import to Java
I have been trying to import and make use of my trained model (Tensorflow, Python) in Java. 我一直在尝试导入并使用Java中训练有素的模型(Tensorflow,Python)。
I was able to save the model in Python, but encountered problems when I try to make predictions using the same model in Java. 我能够在Python中保存模型,但是当我尝试使用Java中的相同模型进行预测时遇到了问题。
Here , you can see the python code for initializing, training, saving the model. 在这里 ,您可以看到用于初始化,训练,保存模型的python代码。
Here , you can see the Java code for importing and making predictions for input values. 在这里 ,您可以看到用于导入和预测输入值的Java代码。
The error message I get is: 我得到的错误信息是:
Exception in thread "main" java.lang.IllegalStateException: Attempting to use uninitialized value Variable_7
[[Node: Variable_7/read = Identity[T=DT_FLOAT, _class=["loc:@Variable_7"], _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_7)]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:285)
at org.tensorflow.Session$Runner.run(Session.java:235)
at org.tensorflow.examples.Identity_import.main(Identity_import.java:35)
I believe, the problem is somewhere in the python code, but I was not able to find it. 我相信,问题出在python代码中,但我无法找到它。
The Java importGraphDef()
function is only importing the computational graph (written by tf.train.write_graph
in your Python code), it isn't loading the values of trained variables (stored in the checkpoint), which is why you get an error complaining about uninitialized variables. Java
importGraphDef()
函数只导入计算图(由Python代码中的tf.train.write_graph
编写),它不加载训练变量的值(存储在检查点中),这就是为什么你会得到一个错误抱怨未初始化的变量。
The TensorFlow SavedModel format on the other hand includes all information about a model (graph, checkpoint state, other metadata) and to use in Java you'd want to use SavedModelBundle.load
to create session initialized with the trained variable values. 另一方面, TensorFlow SavedModel格式包括有关模型的所有信息(图形,检查点状态,其他元数据),并且在Java中使用,您希望使用
SavedModelBundle.load
来创建使用训练变量值初始化的会话。
To export a model in this format from Python, you might want to take a look at a related question Deploy retrained inception SavedModel to google cloud ml engine 要从Python导出这种格式的模型,您可能需要查看相关问题将重新训练开始的SavedModel部署到Google Cloud ml引擎
In your case, this should amount to something like the following in Python: 在您的情况下,这应该类似于Python中的以下内容:
def save_model(session, input_tensor, output_tensor):
signature = tf.saved_model.signature_def_utils.build_signature_def(
inputs = {'input': tf.saved_model.utils.build_tensor_info(input_tensor)},
outputs = {'output': tf.saved_model.utils.build_tensor_info(output_tensor)},
)
b = saved_model_builder.SavedModelBuilder('/tmp/model')
b.add_meta_graph_and_variables(session,
[tf.saved_model.tag_constants.SERVING],
signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature})
b.save()
And invoke that via save_model(session, x, yhat)
并通过
save_model(session, x, yhat)
调用它save_model(session, x, yhat)
And then in Java load the model using: 然后在Java中加载模型使用:
try (SavedModelBundle b = SavedModelBundle.load("/tmp/mymodel", "serve")) {
// b.session().run(...)
}
Hope that helps. 希望有所帮助。
Fwiw, Deeplearning4j lets you import models trained on TensorFlow with Keras 1.0 (Keras 2.0 support is on the way). Fwiw,Deeplearning4j允许您使用Keras 1.0导入在TensorFlow上训练的模型(Keras 2.0支持即将推出)。
https://deeplearning4j.org/model-import-keras https://deeplearning4j.org/model-import-keras
We also built a library called Jumpy, which is a wrapper around Numpy arrays and Pyjnius that uses pointers instead of copying data, which makes it more efficient than Py4j when dealing with tensors. 我们还构建了一个名为Jumpy的库,它是Numpy数组和Pyjnius的包装器,它使用指针而不是复制数据,这使得它在处理张量时比Py4j更有效。
https://deeplearning4j.org/jumpy https://deeplearning4j.org/jumpy
Your python-model will certainly fail at this: 你的python模型肯定会失败:
sess.run(init) #<---this will fail
save_model(sess)
error = tf.reduce_mean(tf.square(prediction - y))
#accuracy = tf.reduce_mean(tf.cast(error, 'float'))
print('Error:', error)
init
is not defined in the model - I'm unsure what you want achieve at this place, but that should give you a starting point init
没有在模型中定义 - 我不确定你想在这个地方实现什么,但这应该给你一个起点
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.