简体   繁体   English

将 Mobilenet Model 转换为 TFLite 会更改输入大小

[英]Converting Mobilenet Model to TFLite changes input size

right now I'm trying to convert a SavedModel to TFLite for use on a raspberry pi.现在我正在尝试将 SavedModel 转换为 TFLite 以在树莓派上使用。 The model is MobileNet Object Detection trained on a custom dataset. model 是在自定义数据集上训练的 MobileNet Object 检测。 The SavedModel works perfectly, and retains the same shape of (1, 150, 150, 3) . SavedModel 完美运行,并保持(1, 150, 150, 3)相同的形状。 However, when I convert it to a TFLite model using this code:但是,当我使用以下代码将其转换为 TFLite model 时:

import tensorflow as tf

saved_model_dir = input("Model dir: ")

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) # path to the SavedModel directory
tflite_model = converter.convert()

# Save the model.
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)

And run this code to run the interpreter:并运行此代码来运行解释器:

import numpy as np
import tensorflow as tf
from PIL import Image

from os import listdir
from os.path import isfile, join

from random import choice, random

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()


input_shape = input_details[0]['shape']
print(f"Required input shape: {input_shape}")

I get an input shape of [1 1 1 3] , therefore I can't use a 150x150 image as input.我得到[1 1 1 3]的输入形状,因此我不能使用 150x150 图像作为输入。

I'm using Tensorflow 2.4 on Python 3.7.10 with Windows 10.我在 Python 3.7.10 上使用 Tensorflow 2.4 和 Windows 10。

How would I fix this?我将如何解决这个问题?

How about calling resize_tensor_input() before calling allocate_tensors()?在调用 allocate_tensors() 之前调用 resize_tensor_input() 怎么样?

interpreter.resize_tensor_input(0, [1, 150, 150, 3], strict=True)
interpreter.allocate_tensors()

You can rely on TFLite converter V1 API to set input shapes.您可以依靠 TFLite 转换器 V1 API 来设置输入形状。 Please check out the input_shapes argument in https://www.tensorflow.org/api_docs/python/tf/compat/v1/lite/TFLiteConverter .请查看https://www.tensorflow.org/api_docs/python/tf/compat/v1/lite/TFLiteConverter中的 input_shapes 参数。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM