繁体   English   中英

有没有办法在 tf.data 管道中使用 tf.keras.model.predict ?

[英]Is there a way to use tf.keras.model.predict within a tf.data pipeline?

我有一个训练有素的 model,我想在tf.data管道中使用第二个 model。 当我尝试这样做时,我得到一个ValueError: Unknown graph. Aborting. ValueError: Unknown graph. Aborting. 我不知道该怎么处理这个错误信息。

我的代码看起来像这样:

def load_data(..., model):
    # code to load an image
    files = tf.data.Dataset.from_tensor_slices(file_list)
    images = files.map(load_image_from_file) 

    def pass_image_through_model(img):
        return model.predict(img, steps=1)

    dataset = images.map(pass_image_through_model)
    return dataset

这有什么问题? 我得到的错误是:

    /home/.../code/dataloader.py:236 pass_image_through_model  *
        return model.predict(img, steps=1)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:1013 predict
        use_multiprocessing=use_multiprocessing)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:728 predict
        callbacks=callbacks)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:189 model_iteration
        f = _make_execution_function(model, mode)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:571 _make_execution_function
        return model._make_execution_function(mode)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:2131 _make_execution_function
        self._make_predict_function()
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:2121 _make_predict_function
        **kwargs)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py:3760 function
        return EagerExecutionFunction(inputs, outputs, updates=updates, name=name)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py:3644 __init__
        raise ValueError('Unknown graph. Aborting.')

    ValueError: Unknown graph. Aborting.

解决此问题的最简单方法之一是将输入直接传递给 model,而不是使用model.predit方法。 原因是model.predict返回numpy.ndarray 这会导致错误,因为tf.data使用图执行,这意味着最好在该图中有任何操作输入和 output 一个张量。

下面是一个快速工作的例子。

import tensorflow as tf

# Create example model
inputs = tf.keras.Input((1,))
out = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, out)

def map_fn(row):
    return model(row)


# Create some input data 
a = tf.constant([1, 2])

# Create the dataset
ds = tf.data.Dataset.from_tensor_slices(a).batch(1)
model_mapped_ds = ds.map(lambda x: map_fn(x))

for el in model_mapped_ds:
    print(el)

最后,下面是您使用时的样子。


def pass_image_through_model(img):
    return model(img) # this returns a tensor 

@tf.function
def load_data(..., model):
    # code to load an image
    files = tf.data.Dataset.from_tensor_slices(file_list).batch(1) # Don't forget batch size!
    images = files.map(load_image_from_file) 

    dataset = images.map(pass_image_through_model)
    return dataset

如果这是您第一次处理tf.data.Dataset() object,您得到的错误可能是无声的。

tf.data.Dataset()中的所有操作实际上都是在图形模式下执行的,您不能使用tf.*中预定义的功能之外的任何功能。

将任意 Python 代码与tf.data.Dataset()混合的唯一方法是使用tf.py_function() ,否则将引发错误。

请记住,将 Python 代码与优化tf.data.Dataset()代码混合会导致时间性能下降。

测试的唯一方法是检索您的数据集,使用as_numpy_iterator()来获取您的数据并使用您的 model 进行预测,因此在映射过程之外。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM