簡體   English   中英

使用 tf.data.Dataset.from_generator() 時的參數化生成器

[英]Parametrized generators while using tf.data.Dataset.from_generator()

我想為我的生成器提供參數以與tf.data.Dataset.from_generator()結合使用。 例如:

def generator(lo, hi):
    for i in range(lo, hi):
        yield float(i)

這個生成器在lohi之間產生浮動。 但是請注意,在創建數據集時,這些參數永遠不會傳遞給此生成器。

tf.data.Dataset.from_generator(generator, tf.float64)

這是因為tf.data.Dataset.from_generator()的 generator 參數不應該帶任何參數。

任何解決方案?

我找到了一個基於名為Partially Applied Functions的函數式編程概念的解決方案。 總之:

PAF 是一個函數,它接受一個具有多個參數的函數並返回一個具有較少參數的函數。

我這樣做的方式如下:

from functools import partial
import tensorflow as tf

def generator(lo, hi):
    for i in range(lo, hi):
        yield float(i)

def get_generator(lo, hi):
    return partial(generator, lo, hi)

tf.data.Dataset(get_generator(lo, hi), tf.float64)

get_generator(lo, hi)函數返回生成器的部分應用函數,該函數修復了lohi參數的值,這實際上是tf.data.Dataset.from_generator()所需的tf.data.Dataset.from_generator()參數生成器。

TensorFlow Dataset已經支持通過參數args對生成器進行參數化,該參數只是傳遞給您的生成器( 請參閱文檔)。 這是在 TensorFlow 2.0.0上測試的最小工作示例。

import tensorflow as tf

x_train = [i for i in range(0, 20, 2)]  # even
x_val = [i for i in range(1, 20, 2)]  # odd
y_train = [i**2 for i in x_train]  # squared
y_val = [i**2 for i in x_val]

def gen_data_epoch(test=False):  # parametrized generator
    train_data = x_val if test else x_train
    label_data = y_val if test else y_train
    n_tests = len(train_data)
    for test_idx in range(len(train_data)):
        yield train_data[test_idx], label_data[test_idx]

def get_dataset(test=False):
    return tf.data.Dataset.from_generator(
        gen_data_epoch, args=(test,),
        output_types=(tf.int32, tf.int32))

print("Train:", [(i[0].numpy(), i[1].numpy()) for i in get_dataset().take(5)])
print("Test: ", [(i[0].numpy(), i[1].numpy()) for i in get_dataset(test=True).take(5)])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM