How to read the CIFAR-10 dataset in TensorFlow?

Can anyone give clean code to load CIFAR-10 in TensorFlow?

I have checked the examples given in TensorFlow's GitHub repo, but I do not want to resize the images to 24x24. Basically, I am looking for easier and simpler code.

Please take a look at the following GitHub page, where I have done this. If the above link fails, please follow the lead on kgeorge.github.io and look at the notebook tf_cifar.ipynb. I have attempted to load the CIFAR-10 data in small steps. Please look for the function load_and_preprocess_input.

The following function from that code accepts data as a NumPy array of shape (nsamples, 32x32x3) and dtype float32, and labels as a NumPy array of nsamples int32 values, and preprocesses the data so that it can be consumed by TensorFlow training.

import numpy as np

image_depth = 3
image_height = 32
image_width = 32

# data: (nsamples, 32*32*3) float32
# labels: (nsamples,) int32
def prepare_input(data=None, labels=None):
    assert data.shape[1] == image_height * image_width * image_depth
    assert data.shape[0] == labels.shape[0]
    # normalize each feature to zero mean and unit standard deviation across all samples
    mu = np.mean(data, axis=0).reshape(1, -1)
    sigma = np.std(data, axis=0).reshape(1, -1)
    data = (data - mu) / sigma
    # a feature with zero standard deviation produces NaN or inf above, so report it
    is_nan = np.isnan(data)
    is_inf = np.isinf(data)
    if np.any(is_nan) or np.any(is_inf):
        print('data is not well-formed : is_nan {n}, is_inf: {i}'.format(n=np.any(is_nan), i=np.any(is_inf)))
    # data is transformed from (no_of_samples, 3072) to (no_of_samples, image_height, image_width, image_depth)
    # each row is stored channel-first, so reshape to (N, depth, height, width) and then move channels last
    data = data.reshape([-1, image_depth, image_height, image_width])
    data = data.transpose([0, 2, 3, 1])
    # make sure the type of the data is np.float32
    data = data.astype(np.float32)
    return data, labels
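
The function above expects the arrays to be in memory already. Here is a minimal sketch of one way to produce them. It assumes you downloaded the "CIFAR-10 python version" archive and extracted it, so that a directory contains the files data_batch_1 through data_batch_5. Each file is a pickled dict with the keys b'data' and b'labels'. The helper names load_batch and load_training_set are made up for this illustration and do not come from the linked notebook.

```python
import os
import pickle

import numpy as np


def load_batch(path):
    """Read one pickled CIFAR-10 batch file and return (data, labels)."""
    with open(path, "rb") as f:
        # The official batches were pickled with Python 2; encoding="bytes"
        # keeps the dict keys as bytes objects such as b"data".
        batch = pickle.load(f, encoding="bytes")
    data = np.asarray(batch[b"data"], dtype=np.float32)    # (n, 3072)
    labels = np.asarray(batch[b"labels"], dtype=np.int32)  # (n,)
    return data, labels


def load_training_set(batch_dir):
    """Concatenate data_batch_1 .. data_batch_5 found in batch_dir."""
    parts = [
        load_batch(os.path.join(batch_dir, "data_batch_%d" % i))
        for i in range(1, 6)
    ]
    data = np.concatenate([d for d, _ in parts], axis=0)
    labels = np.concatenate([l for _, l in parts], axis=0)
    return data, labels
```

The returned pair has the (nsamples, 3072) float32 and (nsamples,) int32 layout that prepare_input asks for, so it can be passed straight in.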

Note that there is now a built-in function to load this dataset.
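
One such built-in loader is tf.keras.datasets.cifar10.load_data, assuming that is the function this answer has in mind. It downloads the dataset on first use and returns uint8 images of shape (N, 32, 32, 3) with labels of shape (N, 1), so no resizing to 24x24 is involved. The sketch below wraps it. The helper names load_cifar10 and to_model_inputs, and the choice to scale pixels to [0, 1], are assumptions made for illustration.

```python
import numpy as np


def to_model_inputs(images, labels):
    """Convert uint8 images and (N, 1) labels to float32 images and (N,) int32 labels."""
    images = images.astype(np.float32) / 255.0
    labels = labels.reshape(-1).astype(np.int32)
    return images, labels


def load_cifar10():
    """Load CIFAR-10 through Keras and return ((x_train, y_train), (x_test, y_test))."""
    # Imported here so that to_model_inputs stays usable without TensorFlow installed.
    import tensorflow as tf

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    return to_model_inputs(x_train, y_train), to_model_inputs(x_test, y_test)
```

Calling load_cifar10() needs TensorFlow installed and, the first time, network access to fetch the data.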
