繁体   English   中英

使用 tf.data.Dataset.from_generator 时出错

[英]Error when using tf.data.Dataset.from_generator

我正在尝试使用 tensorflow from_generator 制作 tensorflow 数据集,我很确定我已经制作了一个运行良好的 python 生成器,但是当我尝试将其传递给 from_generator 时总是出错。 这是我用来创建数据集的一段代码

def dataset_generator(X, Y):
    for idx in range(X.shape[0]):
        img = X[idx, :, :, :]
        labels = Y[idx, :]
        yield img, labels

import tensorflow as tf
ds_generator = dataset_generator(X_data, Y_data)
ds = tf.data.Dataset.from_generator(ds_generator, output_signature=(tf.TensorSpec(shape=[None, 720, 720, 3], dtype=tf.int32), tf.TensorSpec(shape=[None, 30], dtype=tf.float16)))

但是当我运行它时,它总是会产生错误

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-63-af75191f4a28> in <module>
      1 import tensorflow as tf
      2 ds_generator = dataset_generator(X_data, Y_data)
----> 3 ds = tf.data.Dataset.from_generator(ds_generator, output_signature=(tf.TensorSpec(shape=[None, 720, 720, 3], dtype=tf.int32), tf.TensorSpec(shape=[None, 30], dtype=tf.float16)))

~/.local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)

~/.local/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py in from_generator(generator, output_types, output_shapes, args, output_signature)

TypeError: `generator` must be callable.

嗨,您的 gen function 的问题是您必须通过 args 命令来传递它,而不是像 function 这样

import tensorflow as tf
import numpy as np

# Gen Function
def dataset_generator(X, Y):
    for idx in range(X.shape[0]):
        img = X[idx, :, :, :]
        labels = Y[idx, :]
        yield img, labels

# Created random data for testing
X_data = np.random.randn(100, 720, 720, 3).astype(np.float32)
Y_data = tf.one_hot(np.random.randint(0, 30, (100, )), 30)

# Testing function
ds = tf.data.Dataset.from_generator(
    dataset_generator,
    args=(X_data, Y_data), 
    output_types=(tf.float32, tf.uint8)
)

# Get output
next(iter(ds.batch(10).take(1)))

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM