[英]Parametrized generators while using tf.data.Dataset.from_generator()
我想为我的生成器提供参数以与tf.data.Dataset.from_generator()
结合使用。 例如:
def generator(lo, hi):
for i in range(lo, hi):
yield float(i)
这个生成器在lo
和hi
之间产生浮动。 但是请注意,在创建数据集时,这些参数永远不会传递给此生成器。
tf.data.Dataset.from_generator(generator, tf.float64)
这是因为tf.data.Dataset.from_generator()
的 generator 参数不应该带任何参数。
任何解决方案?
我找到了一个基于名为Partially Applied Functions的函数式编程概念的解决方案。 总之:
PAF 是一个函数,它接受一个具有多个参数的函数并返回一个具有较少参数的函数。
我这样做的方式如下:
from functools import partial
import tensorflow as tf
def generator(lo, hi):
for i in range(lo, hi):
yield float(i)
def get_generator(lo, hi):
return partial(generator, lo, hi)
tf.data.Dataset(get_generator(lo, hi), tf.float64)
get_generator(lo, hi)
函数返回生成器的部分应用函数,该函数修复了lo
和hi
参数的值,这实际上是tf.data.Dataset.from_generator()
所需的tf.data.Dataset.from_generator()
参数生成器。
TensorFlow Dataset
已经支持通过参数args
对生成器进行参数化,该参数只是传递给您的生成器( 请参阅文档)。 这是在 TensorFlow 2.0.0
上测试的最小工作示例。
import tensorflow as tf
x_train = [i for i in range(0, 20, 2)] # even
x_val = [i for i in range(1, 20, 2)] # odd
y_train = [i**2 for i in x_train] # squared
y_val = [i**2 for i in x_val]
def gen_data_epoch(test=False): # parametrized generator
train_data = x_val if test else x_train
label_data = y_val if test else y_train
n_tests = len(train_data)
for test_idx in range(len(train_data)):
yield train_data[test_idx], label_data[test_idx]
def get_dataset(test=False):
return tf.data.Dataset.from_generator(
gen_data_epoch, args=(test,),
output_types=(tf.int32, tf.int32))
print("Train:", [(i[0].numpy(), i[1].numpy()) for i in get_dataset().take(5)])
print("Test: ", [(i[0].numpy(), i[1].numpy()) for i in get_dataset(test=True).take(5)])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.