繁体   English   中英

如何正确使用Tensorflow数据集和批处理?

[英]How to use properly Tensorflow Dataset with batch?

我是Tensorflow和深度学习的新手,我正在努力学习数据集课程。 我尝试了很多东西,但我找不到一个好的解决方案。

我在想什么

我有大量的图像(500k +)来训练我的DNN。 这是一个去噪自动编码器,所以我有一对每个图像。 我正在使用TF的数据集类来管理数据,但我认为我使用它非常糟糕。

以下是我在数据集中加载文件名的方法:

class Data:
def __init__(self, in_path, out_path):
    self.nb_images = 512
    self.test_ratio = 0.2
    self.batch_size = 8

    # load filenames in input and outputs
    inputs, outputs, self.nb_images = self._load_data_pair_paths(in_path, out_path, self.nb_images)

    self.size_training = self.nb_images - int(self.nb_images * self.test_ratio)
    self.size_test = int(self.nb_images * self.test_ratio)

    # split arrays in training / validation
    test_data_in, training_data_in = self._split_test_data(inputs, self.test_ratio)
    test_data_out, training_data_out = self._split_test_data(outputs, self.test_ratio)

    # transform array to tf.data.Dataset
    self.train_dataset = tf.data.Dataset.from_tensor_slices((training_data_in, training_data_out))
    self.test_dataset = tf.data.Dataset.from_tensor_slices((test_data_in, test_data_out))

我有一个函数可以在每个时代调用来准备数据集。 它会对文件名进行随机播放,并将文件名转换为图像和批处理数据。

def get_batched_data(self, seed, batch_size):
    nb_batch = int(self.size_training / batch_size)

    def img_to_tensor(path_in, path_out):
        img_string_in = tf.read_file(path_in)
        img_string_out = tf.read_file(path_out)
        im_in = tf.image.decode_jpeg(img_string_in, channels=1)
        im_out = tf.image.decode_jpeg(img_string_out, channels=1)
        return im_in, im_out

    t_datas = self.train_dataset.shuffle(self.size_training, seed=seed)
    t_datas = t_datas.map(img_to_tensor)
    t_datas = t_datas.batch(batch_size)
    return t_datas

现在在训练期间,在每个时代我们调用get_batched_data函数,创建一个迭代器,并为每个批处理运行它,然后将数组提供给优化器操作。

for epoch in range(nb_epoch):
    sess_iter_in = tf.Session()
    sess_iter_out = tf.Session()

    batched_train = data.get_batched_data(epoch)
    iterator_train = batched_train.make_one_shot_iterator()
    in_data, out_data = iterator_train.get_next()

    total_batch = int(data.size_training / batch_size)
    for batch in range(total_batch):
        print(f"{batch + 1} / {total_batch}")
        in_images = sess_iter_in.run(in_data).reshape((-1, 64, 64, 1))
        out_images = sess_iter_out.run(out_data).reshape((-1, 64, 64, 1))
        sess.run(optimizer, feed_dict={inputs: in_images,
                                       outputs: out_images})

我需要什么 ?

我需要有一个只加载当前批次图像的管道(否则它不适合内存)我希望以不同的方式为每个时代改组数据集。

问题和问题

第一个问题,我是否以良好的方式使用数据集课程? 我在互联网上看到了非常不同的东西,例如在这篇博客文章中,数据集与占位符一起使用,并在学习数据时使用。 这看起来很奇怪,因为数据都在一个数组中,所以加载到内存中。 在这种情况下,我没有看到使用tf.data.dataset

我通过在数据集上使用repeat(epoch)找到了解决方案,就像这样 ,但在这种情况下,每个时期的shuffle都不会有所不同。

我的实现的第二个问题是在某些情况下我有一个OutOfRangeError 使用少量数据(如示例中的512),它可以正常工作,但是如果数据量较大,则会发生错误。 我认为这是因为糟糕的舍入导致批次数计算错误,或者最后一批数据量较少,但是它发生在115中的批次32中...有没有办法知道batch(n)调用数据集后创建的batch(n)次数?

对不起,这个loooonng问题,但我已经困难了几天。

据我所知, 官方绩效指南是制作输入管道的最佳教材。

我想以不同的方式为每个时代改组数据集。

使用shuffle()和repeat(),您可以为每个时期获得不同的随机播放模式。 您可以使用以下代码进行确认

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4])
dataset = dataset.shuffle(4)
dataset = dataset.repeat(3)

iterator = dataset.make_one_shot_iterator()
x = iterator.get_next()

with tf.Session() as sess:
    for i in range(10):
        print(sess.run(x))

您也可以使用上述官方页面中提到的tf.contrib.data.shuffle_and_repeat。

除了创建数据管道之外,您的代码中存在一些问题。 您将图构造与图执行混淆。 您正在重复创建数据输入管道,因此有许多冗余输入管道和纪元一样多。 您可以通过Tensorboard观察冗余管道。

您应该将图形构造代码放在循环外部,如下面的代码(伪代码)

batched_train = data.get_batched_data()
iterator = batched_train.make_initializable_iterator()
in_data, out_data = iterator_train.get_next()

for epoch in range(nb_epoch):
    # reset iterator's state
    sess.run(iterator.initializer)

    try:
        while True:
            in_images = sess.run(in_data).reshape((-1, 64, 64, 1))
            out_images = sess.run(out_data).reshape((-1, 64, 64, 1))
            sess.run(optimizer, feed_dict={inputs: in_images,
                                           outputs: out_images})
    except tf.errors.OutOfRangeError:
        pass

此外,还有一些不重要的低效代码。 您使用from_tensor_slices()加载了文件路径列表,因此列表已嵌入到图表中。 (有关详细信息,请参阅https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays

最好使用预取,并通过组合图形来减少sess.run调用。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM