
Have an Issue Implementing CIFAR10 in Tensorflow

import os
import pickle

import matplotlib.pyplot as plt  # needed by display_cifar below
import numpy as np

class CifarLoader(object):
    def __init__(self, source_files):
        self._source = source_files
        self._i = 0
        self.images = None
        self.labels = None

    def load(self):
        data = [unpickle(f) for f in self._source] #again a list comprehension
        images = np.vstack([d["data"] for d in data]) #so vstack stacks these arrays in sequence vertically or row wise
        n = len(images)
        self.images = images.reshape(n, 3, 32, 32).transpose(0, 2, 3, 1).astype(float)/255 #number of possible shades for each channel
        self.labels = one_hot(np.hstack([d["labels"] for d in data]), 10)
        return self

    def next_batch(self, batch_size):
        x, y = self.images[self._i:self._i+batch_size], self.labels[self._i:self._i+batch_size]
        self._i = (self._i + batch_size) % len(self.images)
        return x, y

DATA_PATH = "cifar10"

def unpickle(file):
    with open(os.path.join(DATA_PATH, file), 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')  # avoid shadowing the built-in dict
    return batch

def one_hot(vec, vals=10):
    n = len(vec)
    out = np.zeros((n, vals))
    out[range(n), vec] = 1
    return out

class CifarDataManager(object):
    def __init__(self):
        self.train = CifarLoader(["data_batch_{}".format(i) for i in range(1, 6)]).load()
        self.test = CifarLoader(["test_batch"]).load()

def display_cifar(images, size):
    n = len(images)
    plt.figure()
    plt.gca().set_axis_off()
    im = np.vstack([np.hstack([images[np.random.choice(n)] for i in range(size)]) for i in range(size)])
    plt.imshow(im)
    plt.show()

d = CifarDataManager()

print ("Number of train images: {}".format(len(d.train.images)))
print ("Number of train labels: {}".format(len(d.train.labels)))
print ("Number of test images: {}".format(len(d.test.images)))
print ("Number of test labels: {}".format(len(d.test.labels)))
images = d.train.images
display_cifar(images, 10)

And this is the error I'm getting.

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-182-b3f5a6bd2e1d> in <module>()
      7     plt.show()
      8 
----> 9 d = CifarDataManager()
     10 
     11 print ("Number of train images: {}".format(len(d.train.images)))

<ipython-input-181-e85d41d02848> in __init__(self)
      1 class CifarDataManager(object):
      2     def __init__(self):
----> 3         self.train = CifarLoader(["data_batch_{}".format(i) for i in range(1, 6)]).load()
      4         self.test = CifarLoader(["test_batch"]).load()

<ipython-input-179-d96c4afcda51> in load(self)
     12     def load(self):
     13         data = [unpickle(f) for f in self._source] #again a list comprehension
---> 14         images = np.vstack([d["data"] for d in data]) #so vstack stacks these arrays in sequence vertically or row wise
     15         n = len(images)
     16         self.images = images.reshape(n, 3, 32, 32).transpose(0, 2, 3, 1).astype(float)/255 #number of possible shades for each channel

<ipython-input-179-d96c4afcda51> in <listcomp>(.0)
     12     def load(self):
     13         data = [unpickle(f) for f in self._source] #again a list comprehension
---> 14         images = np.vstack([d["data"] for d in data]) #so vstack stacks these arrays in sequence vertically or row wise
     15         n = len(images)
     16         self.images = images.reshape(n, 3, 32, 32).transpose(0, 2, 3, 1).astype(float)/255 #number of possible shades for each channel

KeyError: 'data'

Any help is appreciated! I suspect the issue has to do with pickle and Python 3 and the way it loads the data.

Thank you for checking your files and posting the result. It is now clear that your keys are byte strings (bytes). Since you didn't specify, I can only guess that you are using Python 3, which cannot implicitly convert a bytes object to a string (see the note in this section). Try the following under both Python 2 and Python 3, and you may get a better idea:

d = {b'a': 1, b'b': 2}
print(d.keys())
try:
    print('Key "a" gives: {}'.format(d["a"]))
except Exception as err:
    print('Get "{}"!'.format(err.__class__.__name__))
    print('Key b"a" gives: {}'.format(d[b"a"]))
