繁体   English   中英

使用 tf.data.Dataset.from_generator() 时的参数化生成器

[英]Parametrized generators while using tf.data.Dataset.from_generator()

我想为我的生成器提供参数以与tf.data.Dataset.from_generator()结合使用。 例如:

def generator(lo, hi):
    for i in range(lo, hi):
        yield float(i)

这个生成器在lohi之间产生浮动。 但是请注意,在创建数据集时,这些参数永远不会传递给此生成器。

tf.data.Dataset.from_generator(generator, tf.float64)

这是因为tf.data.Dataset.from_generator()的 generator 参数不应该带任何参数。

任何解决方案?

我找到了一个基于名为Partially Applied Functions的函数式编程概念的解决方案。 总之:

PAF 是一个函数,它接受一个具有多个参数的函数并返回一个具有较少参数的函数。

我这样做的方式如下:

from functools import partial
import tensorflow as tf

def generator(lo, hi):
    for i in range(lo, hi):
        yield float(i)

def get_generator(lo, hi):
    return partial(generator, lo, hi)

tf.data.Dataset(get_generator(lo, hi), tf.float64)

get_generator(lo, hi)函数返回生成器的部分应用函数,该函数修复了lohi参数的值,这实际上是tf.data.Dataset.from_generator()所需的tf.data.Dataset.from_generator()参数生成器。

TensorFlow Dataset已经支持通过参数args对生成器进行参数化,该参数只是传递给您的生成器( 请参阅文档)。 这是在 TensorFlow 2.0.0上测试的最小工作示例。

import tensorflow as tf

x_train = [i for i in range(0, 20, 2)]  # even
x_val = [i for i in range(1, 20, 2)]  # odd
y_train = [i**2 for i in x_train]  # squared
y_val = [i**2 for i in x_val]

def gen_data_epoch(test=False):  # parametrized generator
    train_data = x_val if test else x_train
    label_data = y_val if test else y_train
    n_tests = len(train_data)
    for test_idx in range(len(train_data)):
        yield train_data[test_idx], label_data[test_idx]

def get_dataset(test=False):
    return tf.data.Dataset.from_generator(
        gen_data_epoch, args=(test,),
        output_types=(tf.int32, tf.int32))

print("Train:", [(i[0].numpy(), i[1].numpy()) for i in get_dataset().take(5)])
print("Test: ", [(i[0].numpy(), i[1].numpy()) for i in get_dataset(test=True).take(5)])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM