簡體   English   中英

捕獲從 TensorFlow 拋出的 python 異常

[英]Catch python exception thrown from TensorFlow

I'm running a python (v 3.6.5) code that is using TensorFlow (v 1.13.2) to perform inference using a trained model (on Windows 8.1).

我想捕獲(並記錄)從 TensorFlow 庫內部拋出的異常/錯誤。

例如,當批處理大小(在 session.run() 期間)太大時,進程使用所有系統 memory 並崩潰。

我的代碼如下所示:

import tensorflow as tf
import math
from tqdm import tqdm
# …

def parse_function(image_string, frame_id):
    image = tf.image.decode_jpeg(image_string, channels=3)
    resize_image = tf.image.resize_images(image, [224, 224], method=tf.image.ResizeMethod.BICUBIC)
    return resize_image, frame_id


def load_graph(frozen_graph_filename):
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="prefix")
    return graph


def main(_):
    batch_size = 128

    num_frames = 5000
    num_batches = int(np.ceil(num_frames / batch_size))
    frame_ids = get_ids()

    with MyFrameReader() as frd:
        im_list = []
        for id in frame_ids:
            im_list.append(frd.get_frame(id))

    dataset = tf.data.Dataset.from_tensor_slices((im_list, frame_ids))
    dataset = dataset.map(parse_function)
    batched_dataset = dataset.batch(batch_size)
    iterator = batched_dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    graph = load_graph(PB_FILE)
    x = graph.get_tensor_by_name('prefix/input_image:0')
    y = graph.get_tensor_by_name('prefix/output_node:0')
    sess1 = tf.Session(graph=graph)
    sess2 = tf.Session(config= tf.ConfigProto(device_count={'GPU': 0})) # Run on CPU
    sess2.run(iterator.initializer)

    for _ in tqdm(range(num_batches)):
        try:
            # pre process
            inference_batch, frame_id_batch = sess2.run(next_element)
            # main process
            scores_np = sess1.run(y, feed_dict={x: inference_batch})
            # post process …
        except MemoryError as e:
            print('Error 1')
        except Exception as e:
            print('Error 2')
        except tf.errors.OpError as e:
            print('Error 3')
        except:
            print('Error 4')
    sess1.close()
    sess2.close()

我看到進程的 memory 增長並且在某些時候它在沒有到達異常處理代碼的情況下死亡。 (如果我在 python 中添加韭菜 memory 的代碼,我設法捕捉到 memory 異常)

有人可以解釋發生了什么嗎?

這應該是由捕獲不正確的異常引起的。 Tensorflow 定義了自己的異常,它們是 Exception 的子類( 源代碼

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM