![](/img/trans.png)
[英]How to load and predict a pre-trained tensorflow model into Java code?
[英]Load pre-trained models in Tensorflow for Java
我正在尝试使用Java API在Tensorflow中加载预训练的模型 。
我注意到随着时间的推移,已保存的模型文件的格式已更改,现在已经保存了文件格式为.pb
, .ckpt
模型以及模型目录为model.ckpt.data-00000-of-00001 , model.ckpt.index
。
我正在遵循读取LabelImage示例中指定的模型的方法。 但是在此示例中,文件格式为protobuf .pb
。 我看到最新保存的模型以.ckpt
或model.ckpt.data-00000-of-00001 , model.ckpt.index
格式保存。
我试图用SavedModelBundle与方法export_dir
包含文件- model.ckpt.data-00000-of-00001
和model.ckpt.index
,但我得到这个错误
`2018-07-18 16:54:00.388790: I tensorflow/cc/saved_model/loader.cc:291] SavedModel load for tags { }; Status: fail. Took 95 microseconds.
Exception in thread "main" org.tensorflow.TensorFlowException: SavedModel not found in export directory: /path/to/model_dir at org.tensorflow.SavedModelBundle.load(Native Method) at org.tensorflow.SavedModelBundle.load(SavedModelBundle.java:39)
有人可以告诉我我做错了什么,或者让我知道如何读取Java中.pb
以外的文件格式保存的已保存模型。
我认为有两种方法可以尝试解决问题:
将保存的模型还原到当前会话后:sess,
# Freeze the graph, with output _node_names is the name of the output when construct the model
# Eg. output_node_names = ["prediction"]
frozen_graph_def = tf.graph_util.convert_variables_to_constants (sess, sess.graph_def, output_node_names)
# Save the frozen graph
with open (frozen_graph_file, "wb") as f:
f.write(frozen_graph_def.SerializeToString()
它应该将以前的格式转换为新的格式。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.