简体   繁体   English

使用 tf.data.Dataset.from_generator() 时的参数化生成器

[英]Parametrized generators while using tf.data.Dataset.from_generator()

I would like to give parameters to my generator to use in combination with tf.data.Dataset.from_generator() .我想为我的生成器提供参数以与tf.data.Dataset.from_generator()结合使用。 For example:例如:

def generator(lo, hi):
    for i in range(lo, hi):
        yield float(i)

This generator yields floats between lo and hi .这个生成器在lohi之间产生浮动。 Notice however than when creating a Dataset, these parameters are never passed to this generator.但是请注意,在创建数据集时,这些参数永远不会传递给此生成器。

tf.data.Dataset.from_generator(generator, tf.float64)

This is because the generator parameter of tf.data.Dataset.from_generator() should take no arguments.这是因为tf.data.Dataset.from_generator()的 generator 参数不应该带任何参数。

Any solutions?任何解决方案?

I found a solution based on a functional programming concept called Partially Applied Functions .我找到了一个基于名为Partially Applied Functions的函数式编程概念的解决方案。 In summary:总之:

a PAF is a function that takes a function with multiple parameters and returns a function with fewer parameters. PAF 是一个函数,它接受一个具有多个参数的函数并返回一个具有较少参数的函数。

The way I did it is the following:我这样做的方式如下:

from functools import partial
import tensorflow as tf

def generator(lo, hi):
    for i in range(lo, hi):
        yield float(i)

def get_generator(lo, hi):
    return partial(generator, lo, hi)

tf.data.Dataset(get_generator(lo, hi), tf.float64)

The get_generator(lo, hi) function returns a partially applied function for the generator which fixes the values for the lo and hi parameters, which is in fact the parameterless generator required by tf.data.Dataset.from_generator() . get_generator(lo, hi)函数返回生成器的部分应用函数,该函数修复了lohi参数的值,这实际上是tf.data.Dataset.from_generator()所需的tf.data.Dataset.from_generator()参数生成器。

TensorFlow Dataset already supports parametrizing the generator through the argument args which is simply passed to your generator ( see docs ). TensorFlow Dataset已经支持通过参数args对生成器进行参数化,该参数只是传递给您的生成器( 请参阅文档)。 Here is a minimal working example tested on TensorFlow 2.0.0 .这是在 TensorFlow 2.0.0上测试的最小工作示例。

import tensorflow as tf

x_train = [i for i in range(0, 20, 2)]  # even
x_val = [i for i in range(1, 20, 2)]  # odd
y_train = [i**2 for i in x_train]  # squared
y_val = [i**2 for i in x_val]

def gen_data_epoch(test=False):  # parametrized generator
    train_data = x_val if test else x_train
    label_data = y_val if test else y_train
    n_tests = len(train_data)
    for test_idx in range(len(train_data)):
        yield train_data[test_idx], label_data[test_idx]

def get_dataset(test=False):
    return tf.data.Dataset.from_generator(
        gen_data_epoch, args=(test,),
        output_types=(tf.int32, tf.int32))

print("Train:", [(i[0].numpy(), i[1].numpy()) for i in get_dataset().take(5)])
print("Test: ", [(i[0].numpy(), i[1].numpy()) for i in get_dataset(test=True).take(5)])

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 使用 tf.data.Dataset.from_generator 时出错 - Error when using tf.data.Dataset.from_generator 使用 tf.data.Dataset.from_generator() 从生成器加载数据 - Loading data from generator using tf.data.Dataset.from_generator() 如何加速 tf.data.Dataset.from_generator() - How to speed up tf.data.Dataset.from_generator() 在 tf.data.Dataset.from_generator() 上应用扩充 - Apply augmentation on tf.data.Dataset.from_generator() 如何使用 tf.data.Dataset.from_generator() 向生成器函数发送参数? - How do you send arguments to a generator function using tf.data.Dataset.from_generator()? 如何在 tf.data.Dataset.from_generator 中保留字典键? - How to preserve dict keys in tf.data.Dataset.from_generator? 使用tf.data.Dataset.from_generator时出现“ SystemError:没有设置异常的错误返回” - “SystemError: error return without exception set” when using tf.data.Dataset.from_generator 如何使用 tf.data.Dataset.from_generator() 从数据集中一次只加载一批? - How to use tf.data.Dataset.from_generator() to load only one batch at a time from the dataset? 如何使用自定义生成器使tf.data.Dataset.from_generator产生批处理 - How to make tf.data.Dataset.from_generator yield batches with a custom generator 从不同数组形状的 tf.data.Dataset.from_generator() 创建一个 padded_batch - create a padded_batch from tf.data.Dataset.from_generator() of different array shapes
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM