
How to manipulate tfds.load() datasets correctly in TensorFlow 2.x?

I am learning how to build an MNIST model from scratch in TensorFlow 2.0 and Keras from a Udemy course.

So, I loaded the MNIST dataset as follows:

import tensorflow_datasets as tfds

mnist_dataset, mnist_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']

Everything was fine: I even got 97% accuracy when testing my model, and I was happy.

The problem started when I tried to do something beyond the course. I tried to display some examples from mnist_dataset with matplotlib's plt.imshow() and failed completely. After some research I found a solution: I needed to load the dataset like this:

mnist_dataset2 = tfds.load(name = 'mnist')
mnistt = mnist_dataset2['train']

where mnistt is the dataset I can manipulate and print using matplotlib.
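
A minimal sketch of plotting from a dataset like mnistt, whose elements are dicts with "image" and "label" keys. So that it runs offline, it builds a small hypothetical stand-in with tf.data.Dataset.from_tensor_slices (random pixels, same feature layout as MNIST) instead of downloading the data; with the real dataset you would pass mnistt itself. The helper name plot_grid is made up for illustration. squeeze(-1) drops the trailing channel of each (28, 28, 1) image so imshow receives a 2-D grayscale array.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs without a display
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for tfds.load(name='mnist')['train']: the same feature
# layout as MNIST (28x28x1 uint8 "image", int64 "label"), but random pixels.
rng = np.random.default_rng(0)
mnistt = tf.data.Dataset.from_tensor_slices({
    "image": rng.integers(0, 256, size=(9, 28, 28, 1), dtype=np.uint8),
    "label": rng.integers(0, 10, size=(9,), dtype=np.int64),
})


def plot_grid(ds):
    """Plot the first 9 examples of a dataset whose elements are feature dicts."""
    fig, axes = plt.subplots(3, 3, figsize=(4, 4))
    for ax, example in zip(axes.flat, ds.take(9)):
        # imshow expects (H, W) for a grayscale image, so drop the channel axis.
        ax.imshow(example["image"].numpy().squeeze(-1), cmap="gray")
        ax.set_title(str(example["label"].numpy()))
        ax.axis("off")
    fig.tight_layout()
    return fig


fig = plot_grid(mnistt)
```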

So my question is: where can I find information about the kinds of objects tfds.load() can return, and how do I manipulate them correctly? (Ideally explained in a way that a TensorFlow beginner like me can follow and build on.)

The main call to tfds.load already contains everything you need:

mnist_dataset, mnist_info = tfds.load(name = 'mnist', with_info=True, as_supervised=True)
  • name="mnist" -> you're specifying the builder you want to use (mnist)
  • with_info=True -> you're asking tfds.load to also return an info object (a tfds.core.DatasetInfo) that describes the returned dataset
  • as_supervised=True -> you're asking tfds.load to return each example as a 2-tuple (image, label), the pair needed for a supervised learning task, instead of a dictionary of features.
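
To see the difference between the two element structures without downloading anything, here is a minimal sketch on a small hypothetical stand-in dataset built with tf.data.Dataset.from_tensor_slices, using the feature layout that mnist_info reports (28x28x1 uint8 image, int64 label). The map() call only approximates what as_supervised=True does with supervised_keys=('image', 'label'); it is not TFDS's actual implementation.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for tfds.load(name='mnist', split='train') with
# as_supervised=False: each element is a dict of features.
images = np.zeros((4, 28, 28, 1), dtype=np.uint8)
labels = np.arange(4, dtype=np.int64)
as_dict = tf.data.Dataset.from_tensor_slices({"image": images, "label": labels})

# Roughly what as_supervised=True gives you: (image, label) tuples built
# from the dataset's supervised_keys=('image', 'label').
as_tuple = as_dict.map(lambda ex: (ex["image"], ex["label"]))

print(as_dict.element_spec)   # a dict of TensorSpec objects
print(as_tuple.element_spec)  # a tuple of TensorSpec objects

example = next(iter(as_dict))         # dict element: index by feature name
image, label = next(iter(as_tuple))   # tuple element: unpack positionally
print(example["image"].shape, image.shape, int(label.numpy()))
```

In practice this means code written for one form (example["image"]) will not work on the other (image, label = example), which is a common source of confusion when mixing tutorials.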

Your first attempt to use mnist_dataset directly (to plot it with matplotlib) failed because mnist_dataset is not a single dataset. As you can see from

print(mnist_info)  # run me!

the dataset contains 2 different splits: train and test.

tfds.core.DatasetInfo(
    name='mnist',
    version=1.0.0,
    description='The MNIST database of handwritten digits.',
    urls=['https://storage.googleapis.com/cvdf-datasets/mnist/'],
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann. lecun. com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
    redistribution_info=,
)

Thus, mnist_dataset (the first value returned by tfds.load) is a dictionary:

{
   "train": <train dataset>,
   "test": <test dataset>
}

That is why the next line of the example extracts the "train" and "test" datasets like this:

mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']
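
Because mnist_dataset is a plain Python dict, ordinary dict operations work on it. The sketch below iterates over it and reports each split's size and element structure. It uses hypothetical stand-in splits (the helper make_split is made up for illustration) so it runs offline. Alternatively, passing split='train' to tfds.load, e.g. tfds.load(name='mnist', split='train', as_supervised=True), returns that single tf.data.Dataset directly instead of a dict.

```python
import numpy as np
import tensorflow as tf


def make_split(n):
    """Hypothetical stand-in for one MNIST split: n (image, label) pairs."""
    images = np.zeros((n, 28, 28, 1), dtype=np.uint8)
    labels = np.zeros((n,), dtype=np.int64)
    return tf.data.Dataset.from_tensor_slices((images, labels))


# Same structure as the first value returned by
# tfds.load(name='mnist', with_info=True, as_supervised=True):
# a dict mapping split names to tf.data.Dataset objects.
mnist_dataset = {"train": make_split(6), "test": make_split(1)}

sizes = {}
for split_name, split_ds in mnist_dataset.items():
    # cardinality() is known here because from_tensor_slices has a fixed size.
    sizes[split_name] = int(split_ds.cardinality().numpy())
    print(split_name, sizes[split_name], split_ds.element_spec)

mnist_train, mnist_test = mnist_dataset["train"], mnist_dataset["test"]
```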

From the mnist_info object you can get all the information you need to manipulate your dataset: the number of splits, the data types (e.g. "image" is a 28x28x1 image with dtype tf.uint8), and so on.

I am getting an error while loading MNIST with this code:

mnist_dataset, mnist_info = tfds.load(name = 'mnist', with_info=True, as_supervised=True)

Error: __init__() missing 2 required positional arguments: 'op' and 'message'

Source: Udemy course

Try this:

x_train, y_train = next(iter(mnist_train))  # take the first (image, label) pair

then plot x_train.
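
A slightly fuller sketch of the same idea, again on a hypothetical random stand-in for mnist_train so that it runs offline. Note that on an unbatched dataset next(iter(...)) returns a single example, so x_train above is one 28x28x1 image and y_train one label; only after .batch() do you get a whole batch at once.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs without a display
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for mnist_train loaded with as_supervised=True.
rng = np.random.default_rng(1)
mnist_train = tf.data.Dataset.from_tensor_slices((
    rng.integers(0, 256, size=(64, 28, 28, 1), dtype=np.uint8),
    rng.integers(0, 10, size=(64,), dtype=np.int64),
))

# Unbatched: next(iter(...)) yields ONE (image, label) pair.
image, label = next(iter(mnist_train))

# Batched: next(iter(...)) yields a batch of 32 images and 32 labels.
images, labels = next(iter(mnist_train.batch(32)))

# Plot the single image; squeeze(-1) turns (28, 28, 1) into (28, 28).
fig, ax = plt.subplots()
ax.imshow(image.numpy().squeeze(-1), cmap="gray")
ax.set_title(f"label: {int(label.numpy())}")
```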

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM