
Parametrized generators while using tf.data.Dataset.from_generator()

I would like to pass parameters to my generator when using it with tf.data.Dataset.from_generator(). For example:

def generator(lo, hi):
    for i in range(lo, hi):
        yield float(i)

This generator yields the floats from lo up to, but not including, hi. Notice, however, that when creating a Dataset, these parameters are never passed to the generator.

tf.data.Dataset.from_generator(generator, tf.float64)

This is because the generator parameter of tf.data.Dataset.from_generator() should take no arguments.

Any solutions?

I found a solution based on a functional programming concept called partial application. In summary:

A partially applied function (PAF) is the result of taking a function with multiple parameters, fixing some of its arguments, and getting back a function with fewer parameters.
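As a rough illustration of that idea outside TensorFlow, the sketch below uses functools.partial on the generator from the question. The names from_zero and zero_to_three are made up for this example. It shows that fixing one argument leaves a one-parameter function, and fixing both leaves a zero-parameter callable that produces a fresh iterator on every call.

```python
from functools import partial

def generator(lo, hi):
    for i in range(lo, hi):
        yield float(i)

# Fix only lo: the result still expects hi.
from_zero = partial(generator, 0)
# Fix both lo and hi: the result takes no arguments at all.
zero_to_three = partial(generator, 0, 3)

# Each call creates a new generator object, so it can be consumed repeatedly.
first_pass = list(zero_to_three())
second_pass = list(zero_to_three())
up_to_two = list(from_zero(2))

print(first_pass, second_pass, up_to_two)
```

The ability to restart matters because a dataset may need to call the function again each time it is iterated.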

The way I did it is the following:

from functools import partial
import tensorflow as tf

def generator(lo, hi):
    for i in range(lo, hi):
        yield float(i)

def get_generator(lo, hi):
    return partial(generator, lo, hi)

lo, hi = 0, 5  # example bounds
dataset = tf.data.Dataset.from_generator(get_generator(lo, hi), tf.float64)

The get_generator(lo, hi) function returns a partially applied version of generator with the values of lo and hi fixed. The result is a callable that takes no arguments, which is exactly what tf.data.Dataset.from_generator() requires.
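A closure should work just as well as functools.partial here. The following is a minimal sketch, assuming TensorFlow 2.x with eager execution. The bounds 0 and 3 and the names make_gen_partial and make_gen_lambda are chosen only for illustration. It builds one dataset from a partial and one from a lambda and reads both back.

```python
from functools import partial
import tensorflow as tf

def generator(lo, hi):
    for i in range(lo, hi):
        yield float(i)

# Two zero-argument callables wrapping the same generator (illustrative bounds).
make_gen_partial = partial(generator, 0, 3)
make_gen_lambda = lambda: generator(0, 3)

ds_partial = tf.data.Dataset.from_generator(
    make_gen_partial, output_types=tf.float64)
ds_lambda = tf.data.Dataset.from_generator(
    make_gen_lambda, output_types=tf.float64)

# Iterating eagerly yields scalar tensors; convert them to Python floats.
values_partial = [float(x.numpy()) for x in ds_partial]
values_lambda = [float(x.numpy()) for x in ds_lambda]
print(values_partial, values_lambda)
```

Which form to use is mostly a matter of taste: partial keeps the fixed arguments visible on the object, while the lambda reads the variables from the enclosing scope when it is called.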

tf.data.Dataset.from_generator() already supports parametrizing the generator through its args argument, whose values are evaluated and passed to your generator as NumPy arguments (see docs). Here is a minimal working example tested on TensorFlow 2.0.0.

import tensorflow as tf

x_train = [i for i in range(0, 20, 2)]  # even
x_val = [i for i in range(1, 20, 2)]  # odd
y_train = [i**2 for i in x_train]  # squared
y_val = [i**2 for i in x_val]

def gen_data_epoch(test=False):  # parametrized generator
    train_data = x_val if test else x_train
    label_data = y_val if test else y_train
    n_tests = len(train_data)
    for test_idx in range(n_tests):
        yield train_data[test_idx], label_data[test_idx]

def get_dataset(test=False):
    return tf.data.Dataset.from_generator(
        gen_data_epoch, args=(test,),
        output_types=(tf.int32, tf.int32))

print("Train:", [(i[0].numpy(), i[1].numpy()) for i in get_dataset().take(5)])
print("Test: ", [(i[0].numpy(), i[1].numpy()) for i in get_dataset(test=True).take(5)])
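For comparison, the same args mechanism can be applied directly to the generator(lo, hi) from the question. This is a sketch assuming TensorFlow 2.x. The bounds 2 and 6 are arbitrary, and the int() casts are a precaution added here because the arguments reach the generator as NumPy values rather than plain Python ints.

```python
import tensorflow as tf

def generator(lo, hi):
    # args arrive as NumPy values, so cast before handing them to range().
    for i in range(int(lo), int(hi)):
        yield float(i)

# The tuple in args is forwarded to generator as its positional arguments.
dataset = tf.data.Dataset.from_generator(
    generator, args=(2, 6), output_types=tf.float64)

values = [float(x.numpy()) for x in dataset]
print(values)
```

With this approach no wrapper function is needed, at the cost of restricting the parameters to values that TensorFlow can convert to tensors.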

