簡體   English   中英

在對象檢測 API 上解釋 TF lite 的輸出

[英]Interpret output of TF lite on object detection API

我正在使用對象檢測 api 來訓練我的自定義數據以解決 2 類問題。 我正在使用 SSD Mobilenet v2。 我正在將模型轉換為 TF lite,並嘗試在 python 解釋器上執行它。 分數和班級的價值對我來說有些混亂,我無法為此做出有效的理由。 我得到以下分數值。

[[ 0.9998122 0.2795332 0.7827836 1.8154384 -1.1171713 0.152002 -0.90076405 1.6943774 -1.1098632 0.6275915 ]]

我得到以下類的值:

[[ 0. 1.742706 0.5762139 -0.23641224 -2.1639721 -0.6644413 -0.60925585 0.5485272 -0.9775026 1.4633082 ]]

對於例如-1.10986321.6943774我如何獲得大於 1 或小於 0 的1.6943774 此外,理想情況下,類應該是整數12因為它是一個2類對象檢測問題

我正在使用以下代碼



    import numpy as np
    import tensorflow as tf
    import cv2

    # Load TFLite model and allocate tensors.
    interpreter = tf.contrib.lite.Interpreter(model_path="C://Users//Admin//Downloads//tflitenew//detect.tflite")
    interpreter.allocate_tensors()

    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print(input_details)
    print(output_details)
    input_shape = input_details[0]['shape']
    print(input_shape)
    # change the following line to feed into your own data.
    #input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)

    input_data = cv2.imread("C:/Users/Admin/Pictures/fire2.jpg")
    #input_data = cv2.imread("C:/Users/Admin/Pictures/images4.jpg")
    #input_data = cv2.imread("C:\\Users\\Admin\\Downloads\\FlareModels\\lessimages\\video5_image_178.jpg")
    input_data = cv2.resize(input_data, (300, 300)) 

    input_data = np.expand_dims(input_data, axis=0)
    input_data = (2.0 / 255.0) * input_data - 1.0
    input_data=input_data.astype(np.float32)
    interpreter.reset_all_variables()
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    output_data_scores = []
    output_data_scores = interpreter.get_tensor(output_details[2]['index'])
    print(output_data_scores)

    output_data_class = []
    output_data_class = interpreter.get_tensor(output_details[1]['index'])
    print(output_data_class)

看起來問題是由錯誤的輸入圖像通道順序引起的。 imread以“BGR”格式讀取圖像。 您可以嘗試添加

input_data = cv2.cvtColor(input_data,  cv2.COLOR_BGR2RGB)

獲取“RGB”格式的圖像,然后查看結果是否合理。

參考: 參考

tflite 模型的輸出需要后期處理。 默認情況下,模型返回固定數量(此處為 10 個檢測)。 使用索引 3 處的輸出張量來獲取有效框的數量num_det (即頂級num_det檢測是有效的,忽略其余的)

num_det = int(interpreter.get_tensor(output_details[3]['index']))
boxes = interpreter.get_tensor(output_details[0]['index'])[0][:num_det]
classes = interpreter.get_tensor(output_details[1]['index'])[0][:num_det]
scores = interpreter.get_tensor(output_details[2]['index'])[0][:num_det]

接下來,需要將框坐標縮放到圖像大小並進行調整,以便框在圖像內(某些可視化 API 需要這樣做)。

df = pd.DataFrame(boxes)
df['ymin'] = df[0].apply(lambda y: max(1,(y*img_height)))
df['xmin'] = df[1].apply(lambda x: max(1,(x*img_width)))
df['ymax'] = df[2].apply(lambda y: min(img_height,(y*img_height)))
df['xmax'] = df[3].apply(lambda x: min(img_width,(x * img_width)))
boxes_scaled = df[['ymin', 'xmin', 'ymax', 'xmax']].to_numpy()

這是帶有輸入預處理、輸出后處理和 mAP 評估的推理腳本的鏈接

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM