
Tensorflow dataset from generator OutOfRangeError: End of sequence

Function that returns a generator:

import glob
import numpy as np
from PIL import Image

def img_x_gen(dir):
    files = glob.glob(f'{dir}/*.jpg')
    for file in files:
        X_i = np.asarray(Image.open(file))
        X_i = X_i / 255.0  # scale pixel values to [0, 1]
        yield X_i, X_i  # It's an autoencoder so return X_i twice

Creating the datasets:

types = (tf.float32, tf.float32)
shapes = (img_shape, img_shape)
ds_train = tf.data.Dataset.from_generator(img_x_gen, types, shapes,
    args=['train/img_sq']).batch(batch_size)
ds_valid = tf.data.Dataset.from_generator(img_x_gen, types, shapes,
    args=['valid/img_sq']).batch(batch_size)

Calling the fit method:

vae.fit(ds_train, epochs=3, validation_data=ds_valid, verbose=True) 

I get the error in the question title:

OutOfRangeError:  End of sequence

The training set has 894 examples and the validation set has 247. batch_size is 32. I know the model works if I load the data into memory.

I've also tried making a generator and batching manually (and passing steps_per_epoch and validation_steps to the model.fit method), but that runs into a similar error: Your input ran out of data; interrupting training.
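For the manual-batching attempt, it may help to work out how many batches each set can actually supply. The sketch below uses the counts from the question; the variable names are only illustrative. It also assumes, as far as I understand Keras, that a plain Python generator passed to fit is not restarted between epochs, so a single pass over the files cannot cover three epochs.

```python
import math

# Counts taken from the question.
n_train, n_valid, batch_size, epochs = 894, 247, 32, 3

# Number of batches per pass, counting the final partial batch.
steps_per_epoch = math.ceil(n_train / batch_size)
validation_steps = math.ceil(n_valid / batch_size)

# Size of the last (partial) batch in each set.
last_train_batch = n_train - (steps_per_epoch - 1) * batch_size
last_valid_batch = n_valid - (validation_steps - 1) * batch_size

# Batches a non-restarting training generator would have to supply in total.
total_train_batches = steps_per_epoch * epochs

print(steps_per_epoch, validation_steps)   # 28 8
print(last_train_batch, last_valid_batch)  # 30 23
print(total_train_batches)                 # 84
```

A generator that makes one pass over the training files yields 28 batches, while three epochs ask for 84, which would be consistent with the "ran out of data" message.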

So clearly I don't understand something about generators.

You can use model.fit_generator() with generators instead of model.fit() (note that fit_generator is deprecated as of TensorFlow 2.1, where model.fit() accepts generators itself). You are receiving this error because your generator does not yield the required number of values for your input. A quick fix is to make it an infinite generator:

def img_x_gen(dir):
    while True:
        # Make your generator infinite
        files = glob.glob(f'{dir}/*.jpg')
        for file in files:
            X_i = np.asarray(Image.open(file))
            X_i = X_i / 255.0
            yield X_i, X_i 
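To see what the while True wrapper changes, here is a small standard-library sketch. The file names and step counts are made up for illustration, and islice stands in for the training loop pulling a fixed number of items.

```python
from itertools import islice


def finite_gen(items):
    # One pass over the items, then the generator is exhausted.
    for item in items:
        yield item


def infinite_gen(items):
    # Restart the pass forever, as in the answer above.
    while True:
        for item in items:
            yield item


items = ["a.jpg", "b.jpg", "c.jpg"]  # hypothetical file names
steps_per_epoch, epochs = 3, 2
needed = steps_per_epoch * epochs    # items the consumer will ask for

from_finite = list(islice(finite_gen(items), needed))
from_infinite = list(islice(infinite_gen(items), needed))

print(len(from_finite), len(from_infinite))  # 3 6
```

Since the infinite version never ends on its own, the consumer has to be told when an epoch stops, which is what steps_per_epoch and validation_steps are for.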

Because your dataset is iterated to the end, use the .repeat() function to repeat your dataset:

dataset = dataset.repeat()
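One more thing that may be worth ruling out, since the question passes the directory through args=[...]: as far as I know, tf.data.Dataset.from_generator converts args to tensors and hands them to the generator as NumPy values, so a Python string shows up as bytes. If that holds here, f'{dir}/*.jpg' would format the bytes repr into the pattern, glob would match nothing, and the generator would end immediately, which would also give End of sequence. The sketch below uses only the standard library to show the effect; normalize_dir is a hypothetical helper, and the temporary directory stands in for train/img_sq.

```python
import glob
import os
import tempfile


def normalize_dir(dir_arg):
    """Return dir_arg as str, decoding it first if it arrived as bytes."""
    if isinstance(dir_arg, bytes):
        return dir_arg.decode("utf-8")
    return dir_arg


with tempfile.TemporaryDirectory() as tmp:
    # Create one empty placeholder file so the glob has something to match.
    open(os.path.join(tmp, "a.jpg"), "wb").close()

    # Assumption: this is the form in which the generator receives the path.
    as_bytes = tmp.encode("utf-8")

    naive_pattern = f"{as_bytes}/*.jpg"                  # "b'...'/*.jpg"
    fixed_pattern = f"{normalize_dir(as_bytes)}/*.jpg"   # ".../*.jpg"

    naive_matches = glob.glob(naive_pattern)
    fixed_matches = glob.glob(fixed_pattern)

print(len(naive_matches), len(fixed_matches))  # 0 1
```

If this is what is happening, decoding dir at the top of img_x_gen would be the thing to try first. Note also that wrapping an empty file list in while True would loop forever without yielding anything.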
