
Catch python exception thrown from TensorFlow

I'm running Python (v 3.6.5) code that uses TensorFlow (v 1.13.2) to perform inference with a trained model (on Windows 8.1).

I want to catch (and log) exceptions/errors that are thrown from inside the TensorFlow library.

For example, when the batch size (during a session.run()) is too large, the process uses all system memory and crashes.

My code looks like this:

import tensorflow as tf
import numpy as np
import math
from tqdm import tqdm
# …

def parse_function(image_string, frame_id):
    image = tf.image.decode_jpeg(image_string, channels=3)
    resize_image = tf.image.resize_images(image, [224, 224], method=tf.image.ResizeMethod.BICUBIC)
    return resize_image, frame_id


def load_graph(frozen_graph_filename):
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="prefix")
    return graph


def main(_):
    batch_size = 128

    num_frames = 5000
    num_batches = int(np.ceil(num_frames / batch_size))
    frame_ids = get_ids()

    with MyFrameReader() as frd:
        im_list = []
        for id in frame_ids:
            im_list.append(frd.get_frame(id))

    dataset = tf.data.Dataset.from_tensor_slices((im_list, frame_ids))
    dataset = dataset.map(parse_function)
    batched_dataset = dataset.batch(batch_size)
    iterator = batched_dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    graph = load_graph(PB_FILE)
    x = graph.get_tensor_by_name('prefix/input_image:0')
    y = graph.get_tensor_by_name('prefix/output_node:0')
    sess1 = tf.Session(graph=graph)
    sess2 = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}))  # Run on CPU
    sess2.run(iterator.initializer)

    for _ in tqdm(range(num_batches)):
        try:
            # pre process
            inference_batch, frame_id_batch = sess2.run(next_element)
            # main process
            scores_np = sess1.run(y, feed_dict={x: inference_batch})
            # post process …
        except MemoryError as e:
            print('Error 1')
        except Exception as e:
            print('Error 2')
        except tf.errors.OpError as e:
            print('Error 3')
        except:
            print('Error 4')
    sess1.close()
    sess2.close()

I see that the memory of the process grows, and at some point it dies without ever reaching the exception-handling code. (If I add Python code that leaks memory, I do manage to catch a MemoryError.)

Can someone please explain what is going on?

This should be caused by catching incorrect exceptions. TensorFlow defines its own exceptions, which are subclasses of Exception ( source code ). Because except clauses are matched in order, the `except Exception` branch in your loop swallows any `tf.errors.OpError` before the `except tf.errors.OpError` branch can ever run; the most specific exception types must be listed first.
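The ordering problem can be shown without TensorFlow installed. Below is a minimal sketch using hypothetical stand-in classes that mirror TensorFlow's real hierarchy (in TensorFlow, `tf.errors.ResourceExhaustedError` is a subclass of `tf.errors.OpError`, which is a subclass of `Exception`); the class and function names here are illustrative, not part of any library:

```python
# Stand-in for tf.errors.OpError (in TF it subclasses Exception).
class OpError(Exception):
    pass

# Stand-in for tf.errors.ResourceExhaustedError (raised on OOM in TF).
class ResourceExhaustedError(OpError):
    pass

def handle_as_in_question(raiser):
    """Mirrors the question's ordering: generic Exception before OpError."""
    try:
        raiser()
    except MemoryError:
        return 'Error 1'
    except Exception:   # also matches OpError, shadowing the next clause
        return 'Error 2'
    except OpError:     # unreachable: every OpError is an Exception
        return 'Error 3'

def handle_fixed(raiser):
    """Most specific first, so the OpError branch is reachable."""
    try:
        raiser()
    except MemoryError:
        return 'Error 1'
    except OpError:
        return 'Error 3'
    except Exception:
        return 'Error 2'

def raise_oom():
    raise ResourceExhaustedError('OOM when allocating tensor')

print(handle_as_in_question(raise_oom))  # prints: Error 2
print(handle_fixed(raise_oom))           # prints: Error 3
```

The same reordering applies in your loop: put `except tf.errors.OpError` (or the even more specific `tf.errors.ResourceExhaustedError`) before `except Exception`. Note, however, that no except clause can fire if the operating system kills the process outright when memory is exhausted, so an exception handler is only reached when TensorFlow itself raises the error.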

