
How to quantize inputs and outputs of optimized tflite model

I use the following code to generate a quantized tflite model:

import tensorflow as tf

def representative_dataset_gen():
  for _ in range(num_calibration_steps):
    # Get sample input data as a numpy array in a method of your choosing.
    yield [input]  # placeholder from the TF docs: yield one batch of real calibration data here

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()

But according to post-training quantization:

The resulting model will be fully quantized but still take float input and output for convenience.

To compile the tflite model for the Google Coral Edge TPU, I need quantized input and output as well.

In the model, I see that the first network layer converts the float input to input_uint8 and the last layer converts output_uint8 to the float output. How do I edit the tflite model to get rid of the first and last float layers?
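These conversion layers show up in the model's input and output tensor types, which can be checked with the TFLite interpreter. A minimal sketch, assuming the model converted above was written to a hypothetical model_quant.tflite:

import tensorflow as tf

# Load the converted model (hypothetical filename) and inspect its I/O tensors.
interpreter = tf.lite.Interpreter(model_path="model_quant.tflite")
interpreter.allocate_tensors()

# With the default conversion above, both dtypes are still float32,
# even though the weights and activations inside are quantized.
print(interpreter.get_input_details()[0]["dtype"])
print(interpreter.get_output_details()[0]["dtype"])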

I know that I could set the input and output type to uint8 during conversion, but this is not compatible with any optimizations. The only available option then is to use fake quantization, which results in a bad model.

You can do this by setting inference_input_type and inference_output_type (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/python/lite.py#L460-L476) to an integer type (uint8 in the code below):

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
# The next three lines quantize the model's input and output tensors
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model = converter.convert()
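To verify that the conversion worked, you can save the resulting model and check its I/O tensor types again. A minimal sketch, assuming the flatbuffer is written to a hypothetical model_quant_io.tflite:

import tensorflow as tf

# Save the converted model (hypothetical filename).
with open("model_quant_io.tflite", "wb") as f:
    f.write(tflite_model)

interpreter = tf.lite.Interpreter(model_path="model_quant_io.tflite")
interpreter.allocate_tensors()

# Both should now report uint8, i.e. the float conversion layers are gone.
print(interpreter.get_input_details()[0]["dtype"])
print(interpreter.get_output_details()[0]["dtype"])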
