簡體   English   中英

如何對具有多個輸入和輸出的 Tensorflow lite model 進行推斷?

[英]How to do the inference for a Tensorflow lite model with multiple inputs and outputs?

我創建了一個簡單的 tensorflow 分類 model,我將其轉換並導出為 a.tflite 文件。 對於 model 在我的 android 應用程序中的集成,我遵循了本教程,但它們僅涵蓋推理部分的單個輸入/輸出 model 類型。 查看文檔和其他一些資源后,我實施了以下解決方案:

        // acc and gyro X, Y, Z are my features
        float[] accX = new float[1];
        float[] accY = new float[1];
        float[] accZ = new float[1];

        float[] gyroX = new float[1];
        float[] gyroY = new float[1];
        float[] gyroZ = new float[1];


        Object[] inputs = new Object[]{accX, accY, accZ, gyroX, gyroY, gyroZ};
        
        // And I have 4 classes
        float[] output1 = new float[1];
        float[] output2 = new float[1];
        float[] output3 = new float[1];
        float[] output4 = new float[1];

        Map<Integer, Object> outputs = new HashMap<>();
        outputs.put(0, output1);
        outputs.put(1, output2);
        outputs.put(2, output3);
        outputs.put(3, output4);

        interpreter.runForMultipleInputsOutputs(inputs, outputs);

但是這段代碼拋出異常:

java.lang.IllegalArgumentException:輸入張量索引無效:1

在這一步,我不確定出了什么問題。

這是我的模型的架構:

 model = tf.keras.Sequential([
        tf.keras.layers.Dense(units=hp_units, input_shape=(6,), activation='relu'),
        tf.keras.layers.Dense(240, activation='relu'),
        tf.keras.layers.Dense(4, activation='softmax')
    ])

解決方案:

根據@Karim Nosseir 的回答,我使用簽名方法訪問我的 model 的輸入和輸出。如果你有一個 model 內置於 python 那么你可以像答案中那樣找到簽名並使用它,如下所示:

Python 簽名:

{'serving_default': {'inputs': ['dense_6_input'], 'outputs': ['dense_8']}}

Android java 使用:

        float[] input = new float[6];
        float[][] output = new float[1][4];
        
        // Run decoding signature.
        try (Interpreter interpreter = new Interpreter(loadModelFile())) {
            Map<String, Object> inputs = new HashMap<>();
            inputs.put("dense_6_input", input);

            Map<String, Object> outputs = new HashMap<>();
            outputs.put("dense_8", output);

            interpreter.runSignature(inputs, outputs, "serving_default");
        } catch (IOException e) {
            e.printStackTrace();
        }

最簡單的方法是使用簽名 API 並將簽名名稱用於輸入/輸出

如果您使用 v2 TFLite Converter,您應該找到定義的簽名。

下面是打印定義了哪些簽名的示例

model = tf.keras.Sequential([
        tf.keras.layers.Dense(4, input_shape=(6,), activation='relu'),
        tf.keras.layers.Dense(240, activation='relu'),
        tf.keras.layers.Dense(4, activation='softmax')
    ])

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
print(interpreter.get_signature_list())

請參閱此處的指南,了解如何運行 Java 和其他語言。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM