
Java calling python function with tensorflow graph

So I have a neural network in TensorFlow (Python 2.7) and I need to retrieve its output from Java. I have a simple Python function, getValue(input), which starts the session and retrieves the value. I am open to any suggestions. I believe Jython won't work because TensorFlow isn't available to it as a library. I need the call to be as fast as possible. JNI lets Java call C code, so could I convert the function with Cython, compile it, and then call it through JNI? Is there a way to pass the data through RAM, or some other way I haven't thought of?

In Python, save the model (using saver.save) and the graph (using tf.train.write_graph).

In Java, use the org.bytedeco.javacpp-presets library to instantiate a GraphDef from the saved protobuf file, then, within a Session, feed in your input features and fetch the output features.

See https://medium.com/google-cloud/how-to-invoke-a-trained-tensorflow-model-from-java-programs-27ed5f4f502d#.4su1s26fz for example code.

I've had the same problem (Java + Python + TensorFlow). I ended up setting up a simple HTTP server. If that's too slow for you, you can shave off some overhead by using sockets directly.
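On the Java side, the HTTP approach comes down to POSTing the input and parsing the reply. The sketch below uses the JDK's own java.net.http.HttpClient (Java 11+) and assumes the Python server exposes a hypothetical /predict endpoint that takes comma-separated floats in the request body and replies in the same format; neither the endpoint nor the format comes from the answer. So that it runs on its own, it also includes a small com.sun.net.httpserver stub that stands in for the Python server and simply doubles each value.

```java
import com.sun.net.httpserver.HttpServer;
import java.io.IOException;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.stream.Collectors;

class ModelHttpClient {
    // One shared client: it keeps connections alive between calls instead of
    // reconnecting for every prediction.
    private static final HttpClient CLIENT = HttpClient.newBuilder()
            .version(HttpClient.Version.HTTP_1_1)
            .build();

    // Sends the features as a comma-separated body and parses a
    // comma-separated reply. The URL and wire format are assumptions.
    static double[] predict(String url, double[] features)
            throws IOException, InterruptedException {
        String body = Arrays.stream(features)
                .mapToObj(f -> Double.toString(f))
                .collect(Collectors.joining(","));
        HttpRequest request = HttpRequest.newBuilder(URI.create(url))
                .header("Content-Type", "text/plain")
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();
        HttpResponse<String> response =
                CLIENT.send(request, HttpResponse.BodyHandlers.ofString());
        if (response.statusCode() != 200) {
            throw new IOException("model server returned HTTP " + response.statusCode());
        }
        return Arrays.stream(response.body().trim().split(","))
                .mapToDouble(Double::parseDouble)
                .toArray();
    }

    // Hypothetical stand-in for the Python server, for local testing only:
    // it "predicts" by doubling every input value. Port 0 picks a free port.
    static HttpServer startStubServer() throws IOException {
        HttpServer server = HttpServer.create(new InetSocketAddress("127.0.0.1", 0), 0);
        server.createContext("/predict", exchange -> {
            String received = new String(
                    exchange.getRequestBody().readAllBytes(), StandardCharsets.UTF_8);
            String reply = Arrays.stream(received.trim().split(","))
                    .map(s -> Double.toString(2 * Double.parseDouble(s)))
                    .collect(Collectors.joining(","));
            byte[] bytes = reply.getBytes(StandardCharsets.UTF_8);
            exchange.sendResponseHeaders(200, bytes.length);
            try (OutputStream os = exchange.getResponseBody()) {
                os.write(bytes);
            }
        });
        server.start();
        return server;
    }
}
```

Usage would look like ModelHttpClient.predict("http://127.0.0.1:8000/predict", features), with the port being whatever the Python server listens on. Reusing one client avoids a new TCP handshake per call; a raw socket with a line-based protocol, as the answer suggests, would strip away the HTTP framing as well.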

Encapsulate your TensorFlow calls in a script, script.py, and then launch it from Java:

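// Starts script.py in a new Python process; read proc.getInputStream() to get its output
Process proc = Runtime.getRuntime().exec(new String[] {"python", "script.py"});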

I'm not sure whether it solves your case.
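The Runtime.exec call only starts the child process; the Java side still has to read what the script prints and wait for it to exit. A minimal sketch of that using ProcessBuilder, assuming (hypothetically) that script.py takes the input values as command-line arguments and prints the prediction to stdout. Keep in mind that each call starts a fresh Python interpreter, so TensorFlow is imported and the model restored every time; if speed matters, that start-up cost will usually dwarf the inference itself.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

class PythonScriptRunner {
    // Runs a command to completion and returns its stdout, with stderr merged
    // in so that Python tracebacks are not silently lost.
    static String run(List<String> command) throws IOException, InterruptedException {
        Process proc = new ProcessBuilder(command)
                .redirectErrorStream(true)
                .start();
        // Drain the output before waitFor(): a child blocked on a full pipe
        // buffer would otherwise never exit.
        String output;
        try (InputStream in = proc.getInputStream()) {
            output = new String(in.readAllBytes(), StandardCharsets.UTF_8);
        }
        int exit = proc.waitFor();
        if (exit != 0) {
            throw new IOException("command exited with " + exit + ": " + output);
        }
        return output.trim();
    }

    // Hypothetical wrapper: passes the features to the script as arguments.
    // The argument/stdout convention of script.py is an assumption.
    static String predict(String pythonExe, String script, double... features)
            throws IOException, InterruptedException {
        List<String> cmd = new ArrayList<>();
        cmd.add(pythonExe);
        cmd.add(script);
        for (double f : features) {
            cmd.add(Double.toString(f));
        }
        return run(cmd);
    }
}
```

For example, PythonScriptRunner.predict("python", "script.py", 0.5, 1.2) would return whatever script.py prints; parsing that text into numbers is left to the caller.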

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM