
Use Variables from frozen tensorflow graph on Android

TLDR: How to use Variables from frozen tensorflow graphs on Android?


1. What I want to do

I have a TensorFlow model that keeps an internal state in multiple variables, created with:

state_var = tf.Variable(tf.zeros(shape, dtype=tf.float32), name='state', trainable=False)

This state is modified during inference:

tf.assign(state_var, new_value)
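Conceptually, the model behaves like an object that owns mutable memory and overwrites it on every call. The Java sketch below is only an analogy, not TensorFlow code. The class name and the update rule (new state = old state + input) are hypothetical stand-ins for the real graph.

```java
/**
 * Conceptual analogy (not TensorFlow code): the model owns its state and
 * overwrites it in place on every inference call, the way tf.assign
 * overwrites a non-trainable tf.Variable.
 */
class InPlaceStateModel {
    // Analogue of state_var = tf.Variable(tf.zeros(shape), trainable=False);
    // Java zero-initializes new arrays.
    private final float[] state;

    InPlaceStateModel(int size) {
        this.state = new float[size];
    }

    /** Hypothetical update rule: new state = old state + input. */
    float[] infer(float[] input) {
        for (int i = 0; i < state.length; i++) {
            state[i] = state[i] + input[i]; // analogue of tf.assign(state_var, new_value)
        }
        return state.clone(); // copy, so callers cannot mutate the internal state
    }
}
```

The point of the analogy is that each call depends on memory written by the previous call. A frozen graph has no such writable memory, which is why the approaches below run into trouble.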

I now want to deploy the model on Android. I was able to get the TensorFlow Android example app running; it loads a frozen model, and that works fine.


2. Restoring variables from frozen graph does not work

However, when you freeze a graph using the freeze_graph script, all Variables are converted to constants. This is fine for the weights of the network, but not for the internal state. Loading the model then fails with the following message, which I interpret as "assign does not work on constant tensors": the Assign op expects a reference to a mutable variable (float_ref), but after freezing, the state is a plain float constant.

java.lang.RuntimeException: Failed to load model from 'file:///android_asset/model.pb'
at org.tensorflow.contrib.android.TensorFlowInferenceInterface.<init>(TensorFlowInferenceInterface.java:113)
...
Caused by: java.io.IOException: Not a valid TensorFlow Graph serialization: Input 0 of node layer_1/Assign was passed float from layer_1/state:0 incompatible with expected float_ref.

Luckily, you can blacklist Variables so that they are not converted to constants (freeze_graph's variable_names_blacklist option). However, this doesn't work either, because the frozen graph now contains uninitialized variables:

java.lang.IllegalStateException: Attempting to use uninitialized value layer_7/state

3. Restoring SavedModel does not work on Android

One last approach I tried is the SavedModel format, which stores both the graph and the variable values. Unfortunately, loading a SavedModel is not supported on Android:

SavedModelBundle bundle = SavedModelBundle.load(modelFilename, modelTag);

// produces error:

E/AndroidRuntime: FATAL EXCEPTION: main
Process: org.tensorflow.demo, PID: 27451
     java.lang.UnsupportedOperationException: Loading a SavedModel is not supported in Android. File a bug at https://github.com/tensorflow/tensorflow/issues if this feature is important to you at org.tensorflow.SavedModelBundle.load(Native Method)

4. How can I make this work?

I don't know what else I can try. Here's what I would imagine, but I don't know how to make it work:

  1. Figure out a way to initialize variables on Android
  2. Figure out a different way to freeze the model, so that maybe the initializer op is also part of the frozen graph and can be run from Android
  3. Find out if/how RNNs/LSTMs are implemented internally, because they should have the same requirement of using variables during inference (and I assume LSTMs can be deployed on Android).
  4. ???

I have solved this myself by taking a different route. To the best of my knowledge, the "variable" concept cannot be used on Android the way I was used to in Python (e.g. you cannot initialize variables and then have the network's internal state updated during inference).

Instead, you can use placeholder and output nodes to keep the state in your Java code and feed it to the network on every inference call:

  • Replace all state tf.Variable occurrences with tf.placeholder. The shape stays the same.
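  • I also defined an additional node to read the updated state: tf.identity(inputs, name='state_output'), where inputs is the tensor holding the new state (the value that was previously written with tf.assign). (Maybe you can simply read the placeholder itself; I haven't tried that.)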
  • During inference on Android, you then feed the initial state into the network.

    float[] values = {0, 0, 0, ...}; // zeros of the correct shape
    inferenceInterface.feed("state", values, ...);

  • After inference, you read the resulting internal state of the network

    float[] values = new float[output_shape]; // output_shape = total number of elements in the state
    inferenceInterface.fetch("state_output", values);

    You then remember this output in Java to pass it into the 'state' placeholder for the next inference call.
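The whole round trip can be sketched in plain Java. This is a minimal sketch under stated assumptions:

  • InferenceBackend is a hypothetical interface that mirrors the feed/run/fetch cycle of TensorFlowInferenceInterface. The real feed also takes the tensor dimensions.
  • The input node name "input" is assumed.
  • AccumulatingBackend is a toy graph (new state = old state + input). It exists only so the example runs without TensorFlow.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/** Hypothetical, simplified stand-in for TensorFlowInferenceInterface's feed/run/fetch cycle. */
interface InferenceBackend {
    void feed(String name, float[] values);
    void run(String[] outputNames);
    void fetch(String name, float[] out);
}

/** Keeps the network's state in Java and threads it through every inference call. */
class StatefulRunner {
    private final InferenceBackend backend;
    private float[] state;

    StatefulRunner(InferenceBackend backend, int stateSize) {
        this.backend = backend;
        this.state = new float[stateSize]; // zero-initialized: the initial state
    }

    /** Runs one inference step and returns a copy of the updated state. */
    float[] step(float[] input) {
        backend.feed("input", input);              // hypothetical input node name
        backend.feed("state", state);              // placeholder that replaced the tf.Variable
        backend.run(new String[] {"state_output"});
        float[] next = new float[state.length];
        backend.fetch("state_output", next);       // the tf.identity(..., name='state_output') node
        state = next;                              // remember it for the next call
        return Arrays.copyOf(state, state.length);
    }

    float[] currentState() {
        return Arrays.copyOf(state, state.length);
    }
}

/** Toy backend (hypothetical): a variable-free "graph" computing state_output = state + input. */
class AccumulatingBackend implements InferenceBackend {
    private final Map<String, float[]> feeds = new HashMap<>();
    private float[] result;

    public void feed(String name, float[] values) {
        feeds.put(name, values.clone());
    }

    public void run(String[] outputNames) {
        float[] s = feeds.get("state");
        float[] x = feeds.get("input");
        result = new float[s.length];
        for (int i = 0; i < s.length; i++) {
            result[i] = s[i] + x[i];
        }
    }

    public void fetch(String name, float[] out) {
        System.arraycopy(result, 0, out, 0, out.length);
    }
}
```

With this design the graph itself contains no variables. It is a pure function of (input, state), so it survives freezing unchanged. Resetting the network is just a matter of creating a new StatefulRunner, which starts from a fresh zeroed state array.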
