簡體   English   中英

使用TensorFlow對實時視頻進行分類

[英]Classify real-time video with TensorFlow

我正在使用本教程來開始使用TensorFlow- TensorFlow for poets

使用retrain.py腳本訓練模型后,我想使用retrained_graph.pb來對視頻進行分類,並在視頻運行時實時查看結果。

我所做的是使用opencv讀取要逐幀分類的視頻。 例如,讀取一個框架,將其保存,打開,分類,使用cv2.imshow()將其與分類結果一起顯示在屏幕上。

它可以工作,但是由於從/向磁盤讀取和寫入幀,因此產生的視頻延遲。

我可以使用從訓練過程中獲得的圖表來對視頻進行分類,而無需逐幀閱讀和保存它嗎?

這是我正在使用的代碼-

with tf.Session(graph=graph) as sess:

video_capture = cv2.VideoCapture(video_path)
i = 0
while True:
    frame = video_capture.read()[1] # get current frame
    frameId = video_capture.get(1) #current frame number
    i = i + 1
    cv2.imwrite(filename="C:\\video_images\\"+ str(i) +".jpg", img=frame) # write frame image to file
    image_data = "C:\\video_images\\" + str(i) + ".jpg"
    t = read_tensor_from_image_file(image_data,
                                    input_height=input_height,
                                    input_width=input_width,
                                    input_mean=input_mean,
                                    input_std=input_std)
    predictions = sess.run(output_operation.outputs[0], {input_operation.outputs[0]: t})
    top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]
    scores = []
    for node_id in top_k:
        human_string = label_lines[node_id]
        score = predictions[0][node_id]
        scores.append([score, human_string])
        #print('%s (score = %.5f)' % (human_string, score))
    #print("\n\n")
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(frame, scores[0][1] + " - " + repr(round(scores[0][0], 2)), (10, 50), font, 1, (0, 0, 255), 2, cv2.LINE_AA)
    cv2.putText(frame, scores[1][1] + " - " + repr(round(scores[1][0], 2)), (10, 100), font, 1, (0, 0, 255), 2, cv2.LINE_AA)
    cv2.imshow("image", frame)
    cv2.waitKey(1)
    os.remove("C:\\video_images\\" + str(i) + ".jpg")

video_capture.release()
cv2.destroyAllWindows()

謝謝。

frame = video_capture.read()[1] # get current frame
float_caster = frame.astype(np.float32)
dims_expander = np.expand_dims(float_caster, axis=0)
resized = cv2.resize(dims_expander,(int(input_width),int(input_height)))
normalized = (resized - input_mean) / input_std
predictions = sess.run(output_operation.outputs[0], {input_operation.outputs[0]: normalized})

獲取框架本身,而不是僅使用imwrite來調用read_tensor_from_image_file 調整大小並標准化。 然后,將normalized會話傳遞。 用這種方法擺脫不必要的磁盤寫/讀操作。

設法解決它。

將read_tensor_from_image_file編輯為以下內容,並僅使用框架而不是image_data進行輸入。

def read_tensor_from_image_file(file_name,
                            input_height=299,
                            input_width=299,
                            input_mean=0,
                            input_std=255):
input_name = "file_reader"
output_name = "normalized"

if type(file_name) is str:
    file_reader = tf.read_file(file_name, input_name)
    if file_name.endswith(".png"):
        image_reader = tf.image.decode_png(file_reader, channels = 3,
                                           name='png_reader')
    elif file_name.endswith(".gif"):
        image_reader = tf.squeeze(tf.image.decode_gif(file_reader,
                                                      name='gif_reader'))
    elif file_name.endswith(".bmp"):
        image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')
    else:
        image_reader = tf.image.decode_jpeg(file_reader, channels = 3,
                                            name='jpeg_reader')
    float_caster = tf.cast(image_reader, tf.float32)
    dims_expander = tf.expand_dims(float_caster, 0);
    resized = tf.image.resize_bilinear(dims_expander, [input_height,
                                                       input_width])
    normalized = tf.divide(tf.subtract(resized, [input_mean]), 
                           [input_std])
    sess = tf.Session()
    result = sess.run(normalized)

elif type(file_name) is np.ndarray:
    resized = cv2.resize(file_name, (input_width, input_height),
                            interpolation=cv2.INTER_LINEAR)
    normalized = (resized - input_mean) / input_std
    result = normalized
    result = array(result).reshape(1, 224, 224, 3)

return result

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM