[英]Tensorflow - Batch predict on multiple images
我有一個人faces
列表,其中列表的每個元素都是一個 numpy 數組,形狀為(1、224、224、3),即人臉圖像。 我有一個 model 的輸入形狀是(None, 224, 224, 3)
和 output 形狀是(None, 2)
。
現在我想對faces
列表中的所有圖像進行預測。 當然,我可以遍歷列表並逐個獲得預測,但我想將所有圖像作為一個批次處理,只使用一次調用model.predict()
來更快地獲得結果。
如果我像現在這樣直接傳遞面孔列表(最后的完整代碼),我只會得到第一張圖像的預測。
print(f"{len(faces)} faces found")
print(faces[0].shape)
maskPreds = model.predict(faces)
print(maskPreds)
Output:
3 faces found
(1, 224, 224, 3)
[[0.9421933 0.05780665]]
但是 3 張圖像的maskPreds
應該是這樣的:
[[0.9421933 0.05780665],
[0.01584494 0.98415506],
[0.09914105 0.9008589 ]]
完整代碼:
from tensorflow.keras.models import load_model
from cvlib import detect_face
import cv2
import numpy as np
def detectAllFaces(frame):
dets = detect_face(frame)
boxes = dets[0]
confidences = dets[1]
faces = []
for box, confidence in zip(boxes, confidences):
startX, startY, endX, endY = box
cv2.rectangle(frame, (startX, startY), (endX, endY), (0, 255, 0), 1)
face = frame[startY:endY, startX:endX]
face = cv2.resize(face, (224, 224))
face = np.expand_dims(face, axis=0) # convert (224,224,3) to (1,224,224,3)
faces.append(face)
return faces, frame
model = load_model("mask_detector.model")
vs = cv2.VideoCapture(0)
model.summary()
while True:
ret, frame = vs.read()
if not ret:
break
faces, frame = detectAllFaces(frame)
if len(faces):
print(f"{len(faces)} faces found")
maskPreds = model.predict(faces) # <==========
print(maskPreds)
cv2.imshow("Window", frame)
if cv2.waitKey(1) == ord('q'):
break
cv2.destroyWindow("Window")
vs.release()
注意:如果我不將每個圖像從 (224, 224, 3) 轉換為 (1, 224, 224, 3),tensorflow 會拋出錯誤,指出輸入尺寸不匹配。
ValueError: Error when checking input: expected input_1 to have 4 dimensions, but got array with shape (224, 224, 3)
如何實現批量預測?
在這種情況下, model.predict()
function 的輸入需要作為 Z2EA9510C37F7F89E4941FF75F7F89E4941FF75F62F21CBZ 形狀數組(N,224,輸入圖像 24 數量)給出。
To achieve this, we can stack the N individual numpy arrays of size ( 1, 224, 224, 3) into one array of size ( N, 224, 224, 3) and then pass it to model.predict()
function.
maskPreds = model.predict(np.vstack(faces))
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.