簡體   English   中英

在 Tensorflow 對象檢測 API 中打印類名和分數

[英]Printing class name and score in Tensorflow Object Detection API

我正在使用 Tensorflow 對象檢測 API 一切正常,但我想打印具有以下格式 {Object name , Score} 或類似內容的字典或數組,我只需要對象名稱和分數。

我嘗試使用以下代碼:

with detection_graph.as_default():
  with tf.Session(graph=detection_graph) as sess:
    # Definite input and output Tensors for detection_graph
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    # Each box represents a part of the image where a particular object was detected.
    detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    # Each score represent how level of confidence for each of the objects.
    # Score is shown on the result image, together with the class label.
    detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
    detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
    for image_path in TEST_IMAGE_PATHS:
      image = Image.open(image_path)
      # the array based representation of the image will be used later in order to prepare the
      # result image with boxes and labels on it.
      image_np = load_image_into_numpy_array(image)
      # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
      image_np_expanded = np.expand_dims(image_np, axis=0)
      # Actual detection.
      (boxes, scores, classes, num) = sess.run(
          [detection_boxes, detection_scores, detection_classes, num_detections],
          feed_dict={image_tensor: image_np_expanded})
      print ([category_index.get(value) for index,value in enumerate(classes[0]) if scores[0,index] > 0.5])

      threshold = 0.5 # in order to get higher percentages you need to lower this number; usually at 0.01 you get 100% predicted objects
      print(len(np.where(scores[0] > threshold)[0])/num_detections[0])

這部分正在工作

  print ([category_index.get(value) for index,value in enumerate(classes[0]) if scores[0,index] > 0.5])

它正在打印 [{'name': 'computer', 'id': 1}] 他們有什么辦法可以將該對象的分數添加到 dict 中嗎?

我在他們使用的 Stackoverflow 上看到了另一個問題:

 threshold = 0.5 # in order to get higher percentages you need to lower this number; usually at 0.01 you get 100% predicted objects
 print(len(np.where(scores[0] > threshold)[0])/num_detections[0])

這給了我Tensor("truediv:0", dtype=float32)但即使它有效也不夠,因為我沒有對象的名稱。

謝謝

所以這是對我有用的解決方案。 (如果您仍在尋找解決方案,那就是)

# The following code replaces the 'print ([category_index...' statement
objects = []
for index, value in enumerate(classes[0]):
  object_dict = {}
  if scores[0, index] > threshold:
    object_dict[(category_index.get(value)).get('name').encode('utf8')] = \
                        scores[0, index]
    objects.append(object_dict)
print objects

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM