繁体   English   中英

具有部分混洗的Tensorflow数据集

[英]Tensorflow dataset with partial shuffle

根据文档,我正在使用TensorFlow的数据集API,并且对shuffle()方法感到困惑:

Dataset.shuffle()转换使用与tf.RandomShuffleQueue相似的算法对输入数据集进行随机混洗:它维护固定大小的缓冲区,并从该缓冲区中均匀地随机选择下一个元素。

如果我只是“部分地”改组我的数据集(例如,buffer_size <=元素数),我希望只有第一个buffer_size元素会被改组,但是事实并非如此,请参见示例:

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8])
                         .shuffle(buffer_size=4, seed=42)
                         .batch(2)
iter = dataset.make_initializable_iterator() # create the iterator
el = iter.get_next()
with tf.Session() as sess:
    sess.run(iter.initializer) 
    print('batch:', sess.run(el))

输出:

batch: [2 5]

为什么在这里5? 因为缓冲区大小只有4个? 前2个元素应该在1〜4之内吧? 我在这里想念什么?

谢谢

简短的答案是,可以在任何时候(包括在创建批处理过程中)补充随机缓冲区。

这是您的观察可能发生的方式:

  • 数据集从数据中读取前4个元素。 随机缓冲区现在包含[1、2、3、4]
  • 您请求两个元素(通过创建2个批次的数据集上的get_next())
  • 随机数据集选择2并将下一个元素读取到随机缓冲区中,该缓冲区现在包含[1、3、4、5]。
  • 随机数据集从缓冲区中选取5个。
  • 您的批次[2,5]被退回。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM