简体   繁体   English

如何取消批处理 Tensorflow 2.0 数据集

[英]How to unbatch a Tensorflow 2.0 Dataset

I have a dataset which I create with the following code working with tf.data.Dataset :我有一个使用tf.data.Dataset使用以下代码创建的数据集:

dataset = Dataset.from_tensor_slices(corona_new)
dataset = dataset.window(WINDOW_SIZE, 1, drop_remainder=True)
dataset = dataset.flat_map(lambda x: x.batch(WINDOW_SIZE))
dataset = dataset.map(lambda x: tf.transpose(x))

for i in dataset:
    print(i.numpy())
    break

which when I run it I get the following output (this is an example of one batch):当我运行它时,我得到以下 output (这是一批的例子):

[[  0. 125. 111. 232. 164. 134. 235. 190.] 
 [  0.  14.  16.   7.   9.   7.   6.   8.]
 [  0. 132. 199. 158. 148. 141. 179. 174.]
 [  0.   0.   0.   2.   0.   2.   1.   2.]
 [  0.   0.   0.   0.   3.   5.   0.   0.]]

How can I unbatch them?我怎样才能取消它们的批处理?

Found my solution.找到了我的解决方案。

In TensorFlow 2.0 you can unbatch a tf.data.Dataset by calling the .unbatch() function.在 TensorFlow 2.0 中,您可以通过调用tf.data.Dataset .unbatch() function 来取消批处理 tf.data.Dataset。

example: dataset.unbatch()示例: dataset.unbatch()

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM