How to make tf.data.Dataset.from_generator yield batches with a custom generator
I want to use the tf.data API. My expected workflow looks like the following:
Input image is a 5D tensor with (batch_size, width, height, channels, frames)
First layer is a 3D convolution
I use the tf.data.from_generator function to create an iterator. Later I make an initializable iterator. My code would look something like this:
def custom_gen():
    img = np.random.normal(size=(width, height, channels, frames))
    yield (img, img)  # I train an autoencoder, so x == y

dataset = tf.data.Dataset.batch(batch_size).from_generator(custom_generator)
iter = dataset.make_initializable_iterator()
sess = tf.Session()
sess.run(iter.get_next())
I would expect iter.get_next() to yield a 5D tensor that includes the batch size. However, I even tried to yield the batch size in my own custom_generator, and it does not work. I face an error when I want to initialize the dataset with my placeholder of the input shape (batch_size, width, height, channels, frames).
The Dataset construction process in that example is ill-formed. It should be done in this order, as also established by the official guide on Importing Data:
1. Call one of the basic dataset creation functions or static methods to establish the raw data source (e.g., the static methods from_tensor_slices, from_generator, list_files, etc.).
2. Then apply transformations by chaining adapter methods (e.g., batch).
Thus:
dataset = tf.data.Dataset.from_generator(custom_generator, output_types=(tf.float32, tf.float32)).batch(batch_size)
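To make this concrete, here is a minimal, self-contained sketch of the corrected pipeline. The dimensions, the generator body, and the dtypes are illustrative placeholders (not from the question), and it uses the eager-style iteration of newer TensorFlow versions; in TF 1.x you would instead use make_initializable_iterator and a Session:

```python
import numpy as np
import tensorflow as tf

# Illustrative dimensions (placeholders, not from the question)
width, height, channels, frames = 16, 16, 1, 8
batch_size = 4

def custom_gen():
    # Yield single, unbatched examples; batching is handled by the Dataset.
    for _ in range(8):
        img = np.random.normal(size=(width, height, channels, frames)).astype(np.float32)
        yield img, img  # autoencoder: input == target

# 1) Create the source dataset, 2) then chain the batch transformation.
dataset = tf.data.Dataset.from_generator(
    custom_gen,
    output_types=(tf.float32, tf.float32),
    output_shapes=(
        (width, height, channels, frames),
        (width, height, channels, frames),
    ),
).batch(batch_size)

x, y = next(iter(dataset))
print(tuple(x.shape))  # (4, 16, 16, 1, 8)
```

Declaring output_shapes is optional but lets downstream layers (such as the 3D convolution) infer static shapes; the batch dimension is prepended by batch(), yielding the expected 5D (batch_size, width, height, channels, frames) tensors.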