I would like to give parameters to my generator to use in combination with tf.data.Dataset.from_generator(). For example:
def generator(lo, hi):
    for i in range(lo, hi):
        yield float(i)
This generator yields floats from lo (inclusive) to hi (exclusive). Notice, however, that when creating a Dataset, these parameters are never passed to the generator.
tf.data.Dataset.from_generator(generator, tf.float64)
This is because the generator parameter of tf.data.Dataset.from_generator() should take no arguments.
Any solutions?
I found a solution based on a functional programming concept called Partially Applied Functions (PAFs). In summary: a PAF is the function you get by fixing some of the arguments of a function with multiple parameters, which leaves a function with fewer parameters.
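As a small illustration of partial application on its own, independent of TensorFlow, here is a sketch using functools.partial from the standard library. The power function is a hypothetical example chosen only for illustration.

```python
from functools import partial

def power(base, exponent):
    # Hypothetical two-parameter function, used only for illustration.
    return base ** exponent

# Fix the first positional argument: the result takes one argument instead of two.
power_of_two = partial(power, 2)

# Fix both arguments: the result takes no arguments at all.
two_cubed = partial(power, 2, 3)

print(power_of_two(10))  # same as power(2, 10)
print(two_cubed())       # same as power(2, 3)
```

The second case, where every argument is fixed, is the one that matters here, since it turns a parametrized function into a parameterless one.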
The way I did it is the following:
from functools import partial
import tensorflow as tf
def generator(lo, hi):
    for i in range(lo, hi):
        yield float(i)

def get_generator(lo, hi):
    return partial(generator, lo, hi)

# lo and hi stand for the bounds you want, e.g. 0 and 5
tf.data.Dataset.from_generator(get_generator(lo, hi), tf.float64)
The get_generator(lo, hi) function returns a partially applied version of generator with the values of lo and hi fixed, which is exactly the parameterless generator function required by tf.data.Dataset.from_generator().
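One way to check that the partial really behaves like a parameterless generator function, without involving TensorFlow, is to call it with no arguments and iterate over the result. The sketch below assumes example bounds of 0 and 3 and a hypothetical name make_gen. Each call returns a fresh generator, so the callable can be invoked again whenever the data needs to be iterated from the start.

```python
from functools import partial

def generator(lo, hi):
    for i in range(lo, hi):
        yield float(i)

def get_generator(lo, hi):
    return partial(generator, lo, hi)

# Example bounds (an assumption for this sketch); make_gen takes no arguments.
make_gen = get_generator(0, 3)

# Calling the partial twice gives two independent passes over the same values.
first_pass = list(make_gen())
second_pass = list(make_gen())

print(first_pass)
print(second_pass)
```

Passing generator(0, 3) directly would hand over an already created generator object instead of a callable, which is why the wrapping step is needed.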
TensorFlow Dataset already supports parametrizing the generator through the argument args, which is simply passed to your generator (see docs). Here is a minimal working example tested on TensorFlow 2.0.0.
import tensorflow as tf
x_train = [i for i in range(0, 20, 2)] # even
x_val = [i for i in range(1, 20, 2)] # odd
y_train = [i**2 for i in x_train] # squared
y_val = [i**2 for i in x_val]

def gen_data_epoch(test=False):  # parametrized generator
    train_data = x_val if test else x_train
    label_data = y_val if test else y_train
    n_tests = len(train_data)
    for test_idx in range(n_tests):
        yield train_data[test_idx], label_data[test_idx]

def get_dataset(test=False):
    return tf.data.Dataset.from_generator(
        gen_data_epoch, args=(test,),
        output_types=(tf.int32, tf.int32))

print("Train:", [(i[0].numpy(), i[1].numpy()) for i in get_dataset().take(5)])
print("Test: ", [(i[0].numpy(), i[1].numpy()) for i in get_dataset(test=True).take(5)])