Tensorflow splitting training data to batches
I have a dataset of images as a NumPy array of shape (number of images, length, width, colour range). I would like to split it into batches and feed them to TensorFlow. What is a good way to do it?
First, you could use numpy.split to divide your images into batches (sub-ndarrays). Note that numpy.split with an integer number of sections raises an error unless the array divides evenly; numpy.array_split does not have this restriction. Then you could feed the batches to a tf.Session using its run method with the feed_dict parameter.
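A minimal sketch of the splitting step, assuming a NumPy image array like the one in the question (the data and the helper name split_into_batches are hypothetical). Passing numpy.split a list of split indices instead of a section count fixes the batch size and lets the last batch be shorter, so no error is raised when the array doesn't divide evenly:

```python
import numpy as np

def split_into_batches(images, batch_size):
    """Split `images` along axis 0 into consecutive batches of `batch_size`.

    The last batch may be smaller if len(images) is not a multiple of batch_size.
    """
    # Split points at batch_size, 2*batch_size, ...; giving indices rather
    # than a number of sections lets np.split produce a shorter final batch.
    split_points = list(range(batch_size, len(images), batch_size))
    return np.split(images, split_points)

# Hypothetical data: 10 RGB images of 28x28 pixels
images = np.zeros((10, 28, 28, 3), dtype=np.float32)
batches = split_into_batches(images, batch_size=4)
print([b.shape[0] for b in batches])  # [4, 4, 2]
```

In TensorFlow 1.x style code, each element of batches would then be passed through feed_dict to sess.run, keyed by your own input placeholder.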
I'd also highly recommend looking at the TF MNIST tutorial.
There's a small error in Thomas Pinetz's answer, and I can't comment yet, so here's an extra answer.
int(len(array)/batch_size)
rounds the division down to the nearest integer, so the last, partial batch wouldn't be processed. To round the division up instead, you can use
ceil_int = -(-a//b)
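A small sketch of why the trick works and how it differs from int() on the batch count (the helper name ceil_div and the sample numbers are illustrative). Python's // rounds toward negative infinity, so flooring the negated numerator and negating the result rounds up, using only integer arithmetic:

```python
import math

def ceil_div(a, b):
    # a // b rounds toward negative infinity, so flooring -a and negating
    # the result rounds a / b up (for positive b), with integers only
    return -(-a // b)

n_samples, batch_size = 10, 4
print(int(n_samples / batch_size))        # 2: rounds down, final partial batch skipped
print(ceil_div(n_samples, batch_size))    # 3: rounds up, every sample lands in a batch
print(math.ceil(n_samples / batch_size))  # 3: same result via floating-point division
```

math.ceil(a / b) gives the same answer for ordinary dataset sizes, but it goes through a float, which can lose precision for integers larger than 2**53; the integer-only version avoids that.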
In addition, the last batch might end up very small compared to the rest. You can adjust your batch size slightly to make this less likely. The complete code is shown below:
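def ceil(a, b):
    # Integer ceiling division: rounds a / b up instead of down
    return -(-a // b)

# Assumes `array` (e.g. a NumPy array) and `batch_size` are already defined
n_samples = len(array)
# Shrink the batch size so the samples are spread more evenly over the batches
better_batch_size = ceil(n_samples, ceil(n_samples, batch_size))
for i in range(ceil(n_samples, better_batch_size)):
    batch = array[i * better_batch_size: (i + 1) * better_batch_size]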
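To see what the batch-size adjustment buys, the same logic can be wrapped in a function and compared against plain fixed-size slicing. This is a sketch for illustration; the function names fixed_batches and balanced_batches are hypothetical:

```python
def ceil(a, b):
    # Integer ceiling division
    return -(-a // b)

def fixed_batches(array, batch_size):
    # Plain slicing with a fixed batch size; the last batch gets the remainder
    return [array[i:i + batch_size] for i in range(0, len(array), batch_size)]

def balanced_batches(array, batch_size):
    # Same logic as the complete code in this answer: shrink the batch size
    # so the samples are spread more evenly over the batches
    n_samples = len(array)
    better_batch_size = ceil(n_samples, ceil(n_samples, batch_size))
    return [array[i * better_batch_size:(i + 1) * better_batch_size]
            for i in range(ceil(n_samples, better_batch_size))]

data = list(range(9))
print([len(b) for b in fixed_batches(data, 4)])     # [4, 4, 1]
print([len(b) for b in balanced_batches(data, 4)])  # [3, 3, 3]
```

The adjustment only makes a tiny last batch less likely: with 10 samples and batch_size=4 both versions still give batches of 4, 4 and 2.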
I use something like this:
for bid in range(int(len(array)/batch_size)):
    batch = array[bid*batch_size:(bid+1)*batch_size]
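A variant of the same loop as a generator, where skipping the final partial batch becomes an explicit choice rather than a side effect of int(). This is a sketch; the name iterate_batches and the drop_last parameter are hypothetical:

```python
import numpy as np

def iterate_batches(array, batch_size, drop_last=False):
    """Yield consecutive slices of `array` along axis 0."""
    # Floor division counts only full batches; ceiling division also
    # counts the final partial batch.
    if drop_last:
        n_batches = len(array) // batch_size
    else:
        n_batches = -(-len(array) // batch_size)
    for bid in range(n_batches):
        yield array[bid * batch_size:(bid + 1) * batch_size]

array = np.arange(10)
print([b.tolist() for b in iterate_batches(array, 4)])
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print([b.tolist() for b in iterate_batches(array, 4, drop_last=True)])
# [[0, 1, 2, 3], [4, 5, 6, 7]]
```

Dropping the short batch can be useful when the model expects a fixed batch dimension; otherwise keeping it ensures every sample is seen each epoch.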
If you have already created your dataset (a tf.data.Dataset), you can just use batch() to create batches of the data.
>>> dataset = tf.data.Dataset.range(8)
>>> dataset = dataset.batch(3)
>>> list(dataset.as_numpy_iterator())
[array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]
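Starting from a NumPy image array like the one in the question, a sketch assuming TensorFlow 2.x with eager execution (the image data is hypothetical). Dataset.from_tensor_slices slices along the first axis, so each dataset element is one image, and batch() then groups consecutive images:

```python
import numpy as np
import tensorflow as tf

# Hypothetical data: 10 RGB images of 28x28 pixels
images = np.zeros((10, 28, 28, 3), dtype=np.float32)

# Each dataset element is one image of shape (28, 28, 3)
dataset = tf.data.Dataset.from_tensor_slices(images)
# Group consecutive images; the final batch holds the 2 leftover images
dataset = dataset.batch(4)

batch_sizes = [batch.shape[0] for batch in dataset.as_numpy_iterator()]
print(batch_sizes)  # [4, 4, 2]
```

Passing drop_remainder=True to batch() would discard the short final batch instead, which gives every batch the same static shape.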
You can see more details in the TensorFlow documentation about batch().
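Note: the technical posts on this site follow the CC BY-SA 4.0 license. If you need to repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.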