[英]Parametrized generators while using tf.data.Dataset.from_generator()
我想為我的生成器提供參數以與tf.data.Dataset.from_generator()
結合使用。 例如:
def generator(lo, hi):
for i in range(lo, hi):
yield float(i)
這個生成器在lo
和hi
之間產生浮動。 但是請注意,在創建數據集時,這些參數永遠不會傳遞給此生成器。
tf.data.Dataset.from_generator(generator, tf.float64)
這是因為tf.data.Dataset.from_generator()
的 generator 參數不應該帶任何參數。
任何解決方案?
我找到了一個基於名為Partially Applied Functions的函數式編程概念的解決方案。 總之:
PAF 是一個函數,它接受一個具有多個參數的函數並返回一個具有較少參數的函數。
我這樣做的方式如下:
from functools import partial
import tensorflow as tf
def generator(lo, hi):
for i in range(lo, hi):
yield float(i)
def get_generator(lo, hi):
return partial(generator, lo, hi)
tf.data.Dataset(get_generator(lo, hi), tf.float64)
get_generator(lo, hi)
函數返回生成器的部分應用函數,該函數修復了lo
和hi
參數的值,這實際上是tf.data.Dataset.from_generator()
所需的tf.data.Dataset.from_generator()
參數生成器。
TensorFlow Dataset
已經支持通過參數args
對生成器進行參數化,該參數只是傳遞給您的生成器( 請參閱文檔)。 這是在 TensorFlow 2.0.0
上測試的最小工作示例。
import tensorflow as tf
x_train = [i for i in range(0, 20, 2)] # even
x_val = [i for i in range(1, 20, 2)] # odd
y_train = [i**2 for i in x_train] # squared
y_val = [i**2 for i in x_val]
def gen_data_epoch(test=False): # parametrized generator
train_data = x_val if test else x_train
label_data = y_val if test else y_train
n_tests = len(train_data)
for test_idx in range(len(train_data)):
yield train_data[test_idx], label_data[test_idx]
def get_dataset(test=False):
return tf.data.Dataset.from_generator(
gen_data_epoch, args=(test,),
output_types=(tf.int32, tf.int32))
print("Train:", [(i[0].numpy(), i[1].numpy()) for i in get_dataset().take(5)])
print("Test: ", [(i[0].numpy(), i[1].numpy()) for i in get_dataset(test=True).take(5)])
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.