
TensorFlow dataset from generator

I am using the code below to recursively load images from a directory and derive their labels from the directory names. But when I have more images, it crashes with a memory error. I would like to use a generator, but I am really stuck on it. Could somebody help? The code without a generator is:

import pathlib
import random

import tensorflow as tf

data_dir = "./images"

print(data_dir)
data_root = pathlib.Path(data_dir)
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]

label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
label_to_index = dict((name, index) for index,name in enumerate(label_names))

all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_image_paths]

path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=8)
label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(all_image_labels, tf.int64))
dataset = tf.data.Dataset.zip((image_ds, label_ds))
dataset = dataset.shuffle(params.train_size)  # shuffle buffer size
dataset = dataset.repeat(params.num_epochs)
dataset = dataset.batch(params.batch_size)
dataset = dataset.prefetch(1)  # keep one batch ready to serve

return dataset

You don't need to use tf.data.Dataset.from_generator. Creating a dataset from in-memory image data with tf.data.Dataset.from_tensor_slices embeds the data in the graph as tf.constant() operations, which wastes memory. With a large enough dataset you can hit TensorFlow's 2GB GraphDef limit. You just need to define the dataset using placeholders:

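# TF 1.x API; features and labels are NumPy arrays that are already loaded.
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))

# Feed the arrays when the iterator is initialized, so they are not stored in the graph.
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer, feed_dict={features_placeholder: features, labels_placeholder: labels})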

As explained here: https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays
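If a generator is still preferred, as the question asks, one possible shape is a plain Python generator that walks the label directories and yields one (path, label) pair at a time, so the full list of paths and labels is never built up front. This is only a sketch under assumptions: the function name iter_image_paths_and_labels is hypothetical, and it assumes the same layout as the question, with one subdirectory per label directly under the data directory.

```python
import pathlib


def iter_image_paths_and_labels(data_dir):
    """Lazily yield (path, label_index) pairs, one image at a time.

    Hypothetical helper: assumes data_dir contains one subdirectory per label.
    """
    data_root = pathlib.Path(data_dir)
    # Same label scheme as the question: sorted directory names mapped to indices.
    label_names = sorted(p.name for p in data_root.iterdir() if p.is_dir())
    label_to_index = {name: index for index, name in enumerate(label_names)}
    for label_name in label_names:
        # Sort the files so the iteration order is deterministic.
        for path in sorted((data_root / label_name).iterdir()):
            if path.is_file():
                yield str(path), label_to_index[label_name]
```

A generator like this could then be wrapped with tf.data.Dataset.from_generator, declaring a string path and an integer label as the element types, and the image loading would stay in a map step as in the question.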

