簡體   English   中英

用 tflite 將 Mask-RCNN model 轉換為 android

[英]Convert Mask-RCNN model for android with tflite

我正在嘗試在 android 中部署這個 mask-rcnn model。 我能夠加載 keras 重量,凍結 model 並使用此腳本使用 tflite 1.13 toco 將其轉換為 .tflite model。

似乎這個 model 使用了一些 tflite 不支持的 tf_ops。 因此我不得不使用

converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,tf.lite.OpsSet.SELECT_TF_OPS]

轉換 model。 現在,當我嘗試使用 python 解釋器推斷這個 model 時,我在interpreter.invoke() 中遇到分段錯誤,並且 python 腳本崩潰。

def run_tf_model(model_path="mask_rcnn_coco.tflite"):
    interpreter = tf.lite.Interpreter(model_path)
    interpreter.allocate_tensors()

    # Get input and output tensors.
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]
    print(" input_details", input_details)
    print("output_details",output_details)

    # Test model on random input data.
    input_shape = input_details['shape']
    print("input_shape tflite",input_shape)
    input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
    interpreter.set_tensor(input_details['index'], input_data)
    interpreter.invoke()

    # The function `get_tensor()` returns a copy of the tensor data.
    # Use `tensor()` in order to get a pointer to the tensor.
    output_data = interpreter.get_tensor(output_details['index'])
    print(output_data)

因此,我無法確定我轉換的 model 是否已正確轉換。

PS我打算在android中使用這個model,但是我對android(java或kotlin)tflite api的經驗很少。 如果有人可以指出任何用於學習的資源也會有所幫助。

編輯:我還嘗試使用 java api 在 android 上運行推理。 但是得到以下錯誤tensorflow/lite/kernels/gather.cc:80 0 <= axis && axis < NumDimensions(input). 在此 tensorflow 問題中有詳細說明

您可以使用 TFLite python 解釋器驗證您經過自定義訓練的 tflite model。 參考

import numpy as np
import tensorflow as tf

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test the model on random input data.
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM