
Error when using tf.data.Dataset.from_generator

I am trying to create a TensorFlow dataset with tf.data.Dataset.from_generator. I am quite sure that the Python generator I wrote works perfectly fine, but whenever I pass it to from_generator I get an error. This is the code I use to create the dataset:

def dataset_generator(X, Y):
    for idx in range(X.shape[0]):
        img = X[idx, :, :, :]
        labels = Y[idx, :]
        yield img, labels

import tensorflow as tf
ds_generator = dataset_generator(X_data, Y_data)
ds = tf.data.Dataset.from_generator(ds_generator, output_signature=(tf.TensorSpec(shape=[None, 720, 720, 3], dtype=tf.int32), tf.TensorSpec(shape=[None, 30], dtype=tf.float16)))

But when I run it, it always produces this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-63-af75191f4a28> in <module>
      1 import tensorflow as tf
      2 ds_generator = dataset_generator(X_data, Y_data)
----> 3 ds = tf.data.Dataset.from_generator(ds_generator, output_signature=(tf.TensorSpec(shape=[None, 720, 720, 3], dtype=tf.int32), tf.TensorSpec(shape=[None, 30], dtype=tf.float16)))

~/.local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)

~/.local/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py in from_generator(generator, output_types, output_shapes, args, output_signature)

TypeError: `generator` must be callable.
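The TypeError comes from the first argument. Calling dataset_generator(X_data, Y_data) does not run the loop yet; it returns a generator object, and generator objects are not callable. A small pure-Python check shows the difference. It assumes toy lists X and Y as stand-ins for the real arrays and a simplified generator that uses len() instead of .shape:

```python
# Simplified stand-in for the asker's generator (uses len() so plain lists work)
def dataset_generator(X, Y):
    for idx in range(len(X)):
        yield X[idx], Y[idx]

X = [[1, 2], [3, 4]]  # toy data, not the real images
Y = [0, 1]            # toy labels

gen_obj = dataset_generator(X, Y)   # calling the function returns a generator object
print(callable(dataset_generator))  # True: this is what from_generator expects
print(callable(gen_obj))            # False: passing this raises the TypeError

# A generator object is also single-use: once exhausted it yields nothing more.
# tf.data needs a callable so it can create a fresh generator for each pass.
print(len(list(gen_obj)))  # 2
print(len(list(gen_obj)))  # 0
```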

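Hi, the problem is that from_generator expects a callable, i.e. the generator function itself, not the generator object you get by calling dataset_generator(X_data, Y_data). Pass the function uncalled and supply its arguments through the args parameter. Also, because the generator yields one example at a time, the declared shapes should describe a single element (no leading None batch dimension); batching is added afterwards with ds.batch().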

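import tensorflow as tf
import numpy as np

# Generator function: yields one (image, label) pair at a time
def dataset_generator(X, Y):
    for idx in range(X.shape[0]):
        img = X[idx, :, :, :]
        labels = Y[idx, :]
        yield img, labels

# Random data for testing: 100 images of 720x720x3 and one-hot labels for 30 classes
X_data = np.random.randn(100, 720, 720, 3).astype(np.float32)
Y_data = np.eye(30, dtype=np.uint8)[np.random.randint(0, 30, (100, ))]  # uint8 to match the signature

# Pass the function itself (not dataset_generator(...)); tf.data calls it with args
ds = tf.data.Dataset.from_generator(
    dataset_generator,
    args=(X_data, Y_data),
    output_signature=(
        tf.TensorSpec(shape=(720, 720, 3), dtype=tf.float32),  # a single image, no batch dimension
        tf.TensorSpec(shape=(30,), dtype=tf.uint8),            # a single one-hot label
    )
)

# Get output: one batch of 10 (images, labels)
next(iter(ds.batch(10).take(1)))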
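A variant that is not part of the original answer, sketched under assumptions: instead of args, wrap the call in a zero-argument lambda, which is also a callable. With args, tf.data first converts each argument to a tensor and hands NumPy arrays back to the generator; with a lambda, the generator reads the captured arrays directly. X_small and Y_small are hypothetical small stand-ins for X_data and Y_data so the sketch runs quickly:

```python
import numpy as np
import tensorflow as tf

# Hypothetical small data standing in for X_data / Y_data
X_small = np.random.randn(4, 8, 8, 3).astype(np.float32)
Y_small = np.eye(5, dtype=np.uint8)[np.random.randint(0, 5, size=4)]

def dataset_generator(X, Y):
    # Yields one (image, label) pair at a time
    for idx in range(X.shape[0]):
        yield X[idx], Y[idx]

# The lambda takes no arguments; tf.data calls it each time the dataset
# is iterated and gets a fresh generator every time.
ds = tf.data.Dataset.from_generator(
    lambda: dataset_generator(X_small, Y_small),
    output_signature=(
        tf.TensorSpec(shape=(8, 8, 3), dtype=tf.float32),  # one image, no batch dim
        tf.TensorSpec(shape=(5,), dtype=tf.uint8),         # one one-hot label
    ),
)

images, labels = next(iter(ds.batch(2)))
print(images.shape, labels.shape)  # (2, 8, 8, 3) (2, 5)

# Iterating twice works because each pass calls the lambda again
print(sum(1 for _ in ds), sum(1 for _ in ds))  # 4 4
```

The batch dimension appears only after ds.batch(2), which is why the per-element TensorSpec shapes leave it out.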

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM