Create a TensorFlow Dataset for a CNN from a local dataset

I have a large dataset of black-and-white images in two classes, where each directory name is the class name:

  • the directory SELECTION contains all images with label = selection;
  • the directory NEUTRAL contains all images with label = neutral.
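
A layout like this can be turned into parallel lists of file paths and integer labels. The following is only a minimal sketch, not the code from the question: the helper name `collect_paths_and_labels`, the `*.png` pattern, and the assumption that both class directories sit directly under one root folder are hypothetical.

```python
import pathlib


def collect_paths_and_labels(root, class_names):
    """Return (paths, labels) for PNG files found under root/<CLASS>/.

    Hypothetical helper: assumes each class has its own directory under
    `root`, named after the class in upper case (SELECTION, NEUTRAL).
    """
    root = pathlib.Path(root)
    paths, labels = [], []
    for index, name in enumerate(class_names):
        class_dir = root / name.upper()
        # rglob also finds images in nested folders such as TRAIN_IMG
        for png in sorted(class_dir.rglob('*.png')):
            paths.append(str(png))
            labels.append(index)
    return paths, labels
```

Calling `collect_paths_and_labels(dataset_root, ['selection', 'neutral'])` would then give label 0 to every image under SELECTION and label 1 to every image under NEUTRAL, where `dataset_root` stands for whatever folder holds the two class directories.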

I need to load all these images into a TensorFlow dataset to replace the MNIST dataset in this tutorial.

I've tried to follow this guide and it looks good, but there are some problems that I don't know how to fix. Following the guide, I got this far:

    from __future__ import absolute_import, division, print_function
    import os
    import sys
    import pathlib
    import IPython.display as display
    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    # newer NumPy releases reject threshold=np.nan; sys.maxsize prints arrays in full
    np.set_printoptions(threshold=sys.maxsize)

    tf.enable_eager_execution()
    tf.__version__
    os.system('clear')

    #### some tries for the SELECTION dataset ####

    data_root = pathlib.Path('/Users/matteo/Desktop/DATASET_X/SELECTION/TRAIN_IMG')

    all_image_paths = []
    all_image_labels = []
    for item in data_root.iterdir():
        item_tmp = str(item)
        if 'selection.png' in item_tmp:
            all_image_paths.append(str(item))
            all_image_labels.append(0)

    image_count = len(all_image_paths)
    label_names = ['selection', 'neutral']
    label_to_index = dict((name, index) for index, name in enumerate(label_names))
    img_path = all_image_paths[0]
    img_raw = tf.read_file(img_path)

    img_tensor = tf.image.decode_png(
        contents=img_raw,
        channels=1
    )
    print(img_tensor.numpy().min())
    print(img_tensor.numpy().max())
    #### it works fine till here ####

    #### trying to make a function ####
    #### problems from here ####

    def load_and_decode_image(path):
        print('[LOG:load_and_decode_image]: ' + str(path))
        image = tf.read_file(path)

        image = tf.image.decode_png(
            contents=image,
            channels=3
        )

        return image


    image_path = all_image_paths[0]
    label = all_image_labels[0]

    image = load_and_decode_image(image_path)
    print('[LOG:image.shape]: ' + str(image.shape))

    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)

    print('shape: ', repr(path_ds.output_shapes))
    print('type: ', path_ds.output_types)
    print()
    print('[LOG:path_ds]:' + str(path_ds))

If I load only one item it works, but when I try:

    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)

and print the shape of path_ds, it returns shape: TensorShape([]), so it seems that it doesn't work. If I try to continue following the tutorial with this block:

    # AUTOTUNE is defined earlier in the tutorial; it is missing from the code above
    AUTOTUNE = tf.data.experimental.AUTOTUNE

    image_ds = path_ds.map(load_and_decode_image, num_parallel_calls=AUTOTUNE)
    plt.figure(figsize=(8, 8))
    for n, image in enumerate(image_ds.take(4)):
        print('[LOG:n, image]: ' + str(n) + ', ' + str(image))
        plt.subplot(2, 2, n+1)
        plt.imshow(image)
        plt.grid(False)
        plt.xticks([])
        plt.yticks([])
        plt.xlabel(' selection'.encode('utf-8'))
        plt.title(label_names[label].title())
    plt.show()

it gives me the following error:

    It's not possible open '<string>': The file was not found (file:///Users/matteo/Documents/GitHub/Cnn_Genetic/cnn_genetic/<string>).

but the problem is that I don't know what this file is or why anything goes looking for it. I don't need to plot my images, but I want to understand why it doesn't work. If I copy/paste the tutorial code I have the same problem, so I think there's a problem with the new tf version.

So, if anyone can tell me where I'm going wrong, I'd be very grateful. Thanks for your time.

Your issue is that path_ds should be the image paths as strings, but you try to convert them to a list of tensors.

So to get the tensors you only need:

    image_ds = all_image_paths.map(load_and_decode_image, num_parallel_calls=AUTOTUNE)
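
Note that `map` is a method of `tf.data.Dataset`, while `all_image_paths` in the question is a plain Python list of strings, so the list has to be wrapped in a dataset before the loader can be mapped over it. Below is a minimal sketch of that pipeline, assuming TensorFlow 2.4 or later (where `tf.read_file` is `tf.io.read_file`, `AUTOTUNE` is available as `tf.data.AUTOTUNE`, and eager execution is on by default). The helper `build_image_label_ds` and the two dummy PNG files are hypothetical and exist only to keep the example self-contained.

```python
import pathlib
import tempfile

import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE  # assumption: TensorFlow 2.4 or later


def load_and_decode_image(path):
    # Inside map(), `path` is a scalar tf.string tensor, not a Python str.
    raw = tf.io.read_file(path)
    return tf.image.decode_png(raw, channels=1)  # 1 channel for B/W images


def build_image_label_ds(paths, labels):
    """Hypothetical helper: pair each decoded image with its integer label."""
    # Each element of path_ds is one scalar string, so its shape is ().
    path_ds = tf.data.Dataset.from_tensor_slices(paths)
    image_ds = path_ds.map(load_and_decode_image, num_parallel_calls=AUTOTUNE)
    label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(labels, tf.int64))
    return tf.data.Dataset.zip((image_ds, label_ds))


# Dummy data so the sketch runs anywhere: two 4x4 grayscale PNG files.
tmp_dir = pathlib.Path(tempfile.mkdtemp())
paths = []
for i in range(2):
    pixels = tf.constant(i * 255, dtype=tf.uint8, shape=[4, 4, 1])
    file_path = str(tmp_dir / ('img_%d.png' % i))
    tf.io.write_file(file_path, tf.io.encode_png(pixels))
    paths.append(file_path)

ds = build_image_label_ds(paths, [0, 1])
for image, label in ds:
    print(image.shape, int(label))
```

Seen this way, the `TensorShape([])` printed for `path_ds` in the question is not a sign of failure: it is the shape of a single element, and each element is one scalar path string.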
