[英]How to make tf.data.Dataset.from_generator yield batches with a custom generator
我想使用tf.data
API。 我期望的工作流程如下所示:
输入图像是具有(batch_size, width, height, channels, frames)
的5D张量
第一层是3D卷积
我使用tf.data.from_generator
函数创建一个迭代器。 后来我做了一个初始化的迭代器。
我的代码如下所示:
def custom_gen():
img = np.random.normal((width, height, channels, frames))
yield(img, img) # I train an autoencoder, so the x == y`
dataset = tf.data.Dataset.batch(batch_size).from_generator(custom_generator)
iter = dataset.make_initializable_iterator()
sess = tf.Session()
sess.run(iter.get_next())
我希望iter.get_next()
我生成具有批处理大小的5D张量。 但是,我什至尝试在自己的custom_generator
产生批处理大小,但它不起作用。 当我想用输入形状(batch_size, width, height, channels, frames)
占位符初始化数据集时,遇到一个错误。
该示例中的Dataset
构建过程Dataset
不正确。 应当按照导入数据的官方指南中规定的顺序进行:
from_slice_tensors
, from_generator
, list_files
等)。 batch
)来应用转换 。 从而:
dataset = tf.data.Dataset.from_generator(custom_generator).batch(batch_size)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.