
Description of TF Lite's Toco converter args for quantization aware training

I am currently trying to track down an error concerning the deployment of a TF model with TPU support.

I can get a model without TPU support running, but as soon as I enable quantization, I get lost.

I am in the following situation:

  1. Created a model and trained it
  2. Created an eval graph of the model
  3. Froze the model and saved the result as a protocol buffer (a rough sketch of steps 2 and 3 follows after this list)
  4. Successfully converted and deployed it without TPU support
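
For reference, steps 2 and 3 look roughly like this in my setup (only a sketch: the model-building function and the checkpoint path are placeholders):

import tensorflow as tf

with tf.Graph().as_default() as eval_graph:
    build_eval_model()                                               # placeholder for my model definition
    tf.contrib.quantize.create_eval_graph(input_graph=eval_graph)    # add fake-quant ops for quantization aware training

    with tf.Session(graph=eval_graph) as sess:
        tf.train.Saver().restore(sess, 'train/model.ckpt')           # placeholder checkpoint path
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, eval_graph.as_graph_def(), ['dense/BiasAdd'])

with open('frozen_model.pb', 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())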

For the last point, I used the TFLiteConverter's Python API. The script that produces a functional tflite model is:

import tensorflow as tf

graph_def_file = 'frozen_model.pb'
inputs = ['dense_input']
outputs = ['dense/BiasAdd']

converter = tf.lite.TFLiteConverter.from_frozen_graph(graph_def_file, inputs, outputs)
converter.inference_type = tf.lite.constants.FLOAT                ## float inference, no quantization
input_arrays = converter.get_input_arrays()

converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]

tflite_model = converter.convert()

open('model.tflite', 'wb').write(tflite_model)

This tells me that my approach seems to be OK up to this point. Now, to utilize the Coral TPU stick, I have to quantize my model (which I took into account during training). All I should have to do is modify my converter script. I figured I had to change it to:

import tensorflow as tf

graph_def_file = 'frozen_model.pb'
inputs = ['dense_input']
outputs = ['dense/BiasAdd']

converter = tf.lite.TFLiteConverter.from_frozen_graph(graph_def_file, inputs, outputs)
converter.inference_type = tf.lite.constants.QUANTIZED_UINT8      ## Indicates TPU compatibility
input_arrays = converter.get_input_arrays()

converter.quantized_input_stats = {input_arrays[0]: (0., 1.)}     ## mean, std_dev
converter.default_ranges_stats = (-128, 127)                      ## min, max values for quantization (?)
converter.allow_custom_ops = True                                 ## not sure if this is needed

## REMOVED THE OPTIMIZATIONS ALTOGETHER TO MAKE IT WORK

tflite_model = converter.convert()

open('model.tflite', 'wb').write(tflite_model)

This tflite model produces results when loaded with the interpreter's Python API, but I am not able to understand their meaning. There is also no documentation (or, if there is, it is well hidden) on how to choose mean, std_dev and the min/max ranges. Finally, after compiling this model with the edgetpu_compiler and deploying it (loading it with the C++ API), I receive the following error:

INFO: Initialized TensorFlow Lite runtime.
ERROR: Failed to prepare for TPU. generic::failed_precondition: Custom op already assigned to a different TPU.
ERROR: Node number 0 (edgetpu-custom-op) failed to prepare.

Segmentation fault

I suppose I missed a flag or something during the conversion process, but since the documentation is lacking here as well, I can't say for sure.
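
For completeness, this is roughly how I inspect the quantized model with the interpreter's Python API; the dequantization at the end is my own guess at how the uint8 outputs map back to real values, based on the scale and zero point the interpreter reports:

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

dummy_input = np.zeros(input_details[0]['shape'], dtype=np.uint8)    # placeholder input data
interpreter.set_tensor(input_details[0]['index'], dummy_input)
interpreter.invoke()

raw_output = interpreter.get_tensor(output_details[0]['index'])      # raw uint8 values
scale, zero_point = output_details[0]['quantization']
print(scale * (raw_output.astype(np.float32) - zero_point))          # dequantized output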

In short:

  1. What do the parameters mean, std_dev and min/max do, and how do they interact?
  2. What am I doing wrong during the conversion?

I am grateful for any help or guidance!

EDIT: I have opened a GitHub issue with the full test code. Feel free to play around with it.

You should never need to manually set the quantization stats.
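
(For reference, and based on my reading of the TF 1.x converter docs: quantized_input_stats describes how the uint8 input maps back to floats via real_input = (quantized_input - mean) / std_dev, while default_ranges_stats is only a fallback (min, max) range for tensors without recorded ranges. If you ever did want to derive the stats yourself, a sketch:)

# Derive (mean, std_dev) that map uint8 [0, 255] onto a float range [min_val, max_val],
# assuming real_input = (quantized_input - mean) / std_dev.
def input_stats(min_val, max_val):
  std_dev = 255.0 / (max_val - min_val)
  mean = -min_val * std_dev
  return mean, std_dev

print(input_stats(0.0, 1.0))    # (0.0, 255.0)   -> float inputs in [0, 1]
print(input_stats(-1.0, 1.0))   # (127.5, 127.5) -> float inputs in [-1, 1]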

Have you tried the post-training-quantization tutorials?

https://www.tensorflow.org/lite/performance/post_training_integer_quant

Basically they set the quantization options:

converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

Then they pass a "representative dataset" to the converter, so that the converter can run the model for a few batches and gather the necessary statistics:

def representative_data_gen():
  for input_value in mnist_ds.take(100):
    yield [input_value]

converter.representative_dataset = representative_data_gen

While there are options for quantized training, it's always easier to do post-training quantization.
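
Putting the pieces together for your frozen graph, the conversion would look roughly like this (a sketch: the input shape and the random calibration data are placeholders; use real samples from your training data instead):

import numpy as np
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_frozen_graph(
    'frozen_model.pb', ['dense_input'], ['dense/BiasAdd'])

def representative_data_gen():
  for _ in range(100):
    # Placeholder calibration data; replace with real training samples
    yield [np.random.rand(1, 10).astype(np.float32)]

converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

open('model_quant.tflite', 'wb').write(converter.convert())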
