[英]How to do the inference for a Tensorflow lite model with multiple inputs and outputs?
我創建了一個簡單的 tensorflow 分類 model,我將其轉換並導出為 a.tflite 文件。 對於 model 在我的 android 應用程序中的集成,我遵循了本教程,但它們僅涵蓋推理部分的單個輸入/輸出 model 類型。 查看文檔和其他一些資源后,我實施了以下解決方案:
// acc and gyro X, Y, Z are my features
float[] accX = new float[1];
float[] accY = new float[1];
float[] accZ = new float[1];
float[] gyroX = new float[1];
float[] gyroY = new float[1];
float[] gyroZ = new float[1];
Object[] inputs = new Object[]{accX, accY, accZ, gyroX, gyroY, gyroZ};
// And I have 4 classes
float[] output1 = new float[1];
float[] output2 = new float[1];
float[] output3 = new float[1];
float[] output4 = new float[1];
Map<Integer, Object> outputs = new HashMap<>();
outputs.put(0, output1);
outputs.put(1, output2);
outputs.put(2, output3);
outputs.put(3, output4);
interpreter.runForMultipleInputsOutputs(inputs, outputs);
但是這段代碼拋出異常:
java.lang.IllegalArgumentException:輸入張量索引無效:1
在這一步,我不確定出了什么問題。
這是我的模型的架構:
model = tf.keras.Sequential([
tf.keras.layers.Dense(units=hp_units, input_shape=(6,), activation='relu'),
tf.keras.layers.Dense(240, activation='relu'),
tf.keras.layers.Dense(4, activation='softmax')
])
解決方案:
根據@Karim Nosseir 的回答,我使用簽名方法訪問我的 model 的輸入和輸出。如果你有一個 model 內置於 python 那么你可以像答案中那樣找到簽名並使用它,如下所示:
Python 簽名:
{'serving_default': {'inputs': ['dense_6_input'], 'outputs': ['dense_8']}}
Android java 使用:
float[] input = new float[6];
float[][] output = new float[1][4];
// Run decoding signature.
try (Interpreter interpreter = new Interpreter(loadModelFile())) {
Map<String, Object> inputs = new HashMap<>();
inputs.put("dense_6_input", input);
Map<String, Object> outputs = new HashMap<>();
outputs.put("dense_8", output);
interpreter.runSignature(inputs, outputs, "serving_default");
} catch (IOException e) {
e.printStackTrace();
}
最簡單的方法是使用簽名 API 並將簽名名稱用於輸入/輸出
如果您使用 v2 TFLite Converter,您應該找到定義的簽名。
下面是打印定義了哪些簽名的示例
model = tf.keras.Sequential([
tf.keras.layers.Dense(4, input_shape=(6,), activation='relu'),
tf.keras.layers.Dense(240, activation='relu'),
tf.keras.layers.Dense(4, activation='softmax')
])
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
print(interpreter.get_signature_list())
請參閱此處的指南,了解如何運行 Java 和其他語言。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.