簡體   English   中英

使用 TF 2.0 將 saved_model 轉換為 TFLite 模型

[英]Converting saved_model to TFLite model using TF 2.0

目前我正在將自定義對象檢測模型(使用 SSD 和初始網絡訓練)轉換為量化的 TFLite 模型。 我可以使用以下代碼片段(使用Tensorflow 1.4 )將自定義對象檢測模型從凍結圖轉換為量化的 TFLite 模型:

converter = tf.lite.TFLiteConverter.from_frozen_graph(args["model"],input_shapes = {'normalized_input_image_tensor':[1,300,300,3]},
input_arrays = ['normalized_input_image_tensor'],output_arrays = ['TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1',
'TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3'])

converter.allow_custom_ops=True
converter.post_training_quantize=True 
tflite_model = converter.convert()
open(args["output"], "wb").write(tflite_model)

但是tf.lite.TFLiteConverter.from_frozen_graph類方法不適用於Tensorflow 2.0請參閱此鏈接)。 所以我嘗試使用tf.lite.TFLiteConverter.from_saved_model類方法轉換模型。 代碼片段如下所示:

converter = tf.lite.TFLiteConverter.from_saved_model("/content/") # Path to saved_model directory
converter.optimizations =  [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

上面的代碼片段拋出以下錯誤:

ValueError: None is only supported in the 1st dimension. Tensor 'image_tensor' has invalid shape '[None, None, None, 3]'.

我試圖將input_shapes作為參數傳遞

converter = tf.lite.TFLiteConverter.from_saved_model("/content/",input_shapes={"image_tensor" : [1,300,300,3]})

但它會引發以下錯誤:

TypeError: from_saved_model() got an unexpected keyword argument 'input_shapes'

我錯過了什么嗎? 請隨時糾正我!

我使用tf.compat.v1.lite.TFLiteConverter.from_frozen_graph得到了解決方案。 compat.v1帶來的功能TF1.xTF2.x 以下是完整代碼:

converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph("/content/tflite_graph.pb",input_shapes = {'normalized_input_image_tensor':[1,300,300,3]},
    input_arrays = ['normalized_input_image_tensor'],output_arrays = ['TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1',
    'TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3'])

converter.allow_custom_ops=True

# Convert the model to quantized TFLite model.
converter.optimizations =  [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()


# Write a model using the following line
open("/content/uno_mobilenetV2.tflite", "wb").write(tflite_model)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM