I am learning how to create an MNIST model from scratch in TensorFlow 2.0 and Keras from a Udemy course.
I loaded the MNIST dataset as follows:
import tensorflow_datasets as tfds
mnist_dataset, mnist_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']
Everything worked fine: I even got 97% accuracy when testing my model, and I was happy.
The problem started when I tried to do something the course didn't cover. I tried to display some examples from mnist_dataset with matplotlib's plt.imshow() and failed completely. After some research I found a solution: I needed to load the dataset like this:
mnist_dataset2 = tfds.load(name = 'mnist')
mnistt = mnist_dataset2['train']
where mnistt is the dataset I can manipulate and display with matplotlib.
So my question is: where can I find information about the kinds of objects tfds.load() can return, and how to manipulate them correctly? (Ideally explained in a way that a TensorFlow beginner like me can follow.)
The main invocation of the tfds.load method contains everything you need:
mnist_dataset, mnist_info = tfds.load(name = 'mnist', with_info=True, as_supervised=True)
name="mnist"
-> you're specifiying the builder you want to use (mnist) with_info=True
-> you're asking tfds.load
to return the info
object that contains all you need to know about the returned dataset as_supervised=True
-> you're asking tfds.load
to get only the elements of the dataset needed for a supervised learning task (the image and label pair). Your first attempt of using mnist_dataset
to get the data (to use with matplotlib
) failed because as you can see from
print(mnist_info) #run me!
The dataset contains 2 different splits: train
and test
.
tfds.core.DatasetInfo(
    name='mnist',
    version=1.0.0,
    description='The MNIST database of handwritten digits.',
    urls=['https://storage.googleapis.com/cvdf-datasets/mnist/'],
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann. lecun. com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
    redistribution_info=,
)
Thus, the object returned by tfds.load is a dictionary:
{
"train": <train dataset>,
"test": <test dataset>
}
In fact, in the next line of the example, you extract the "train" and "test" datasets in this way:
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']
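To see this structure concretely without downloading anything, here is a minimal sketch that builds a stand-in for the tfds.load result from random data shaped like MNIST (28x28x1 uint8 images, int64 labels). The helper make_fake_split and the dicts supervised and plain are hypothetical names for illustration only; with the real data you would use mnist_dataset['train'] and mnist_dataset['test'] instead. It shows that the returned object is a dict of splits, and that as_supervised=True yields (image, label) tuples while the default yields dicts with "image" and "label" keys.

```python
import numpy as np
import tensorflow as tf

def make_fake_split(num_examples, as_supervised, seed=0):
    """Hypothetical stand-in for one MNIST split: random pixels, not real digits."""
    rng = np.random.default_rng(seed)
    images = rng.integers(0, 256, size=(num_examples, 28, 28, 1)).astype(np.uint8)
    labels = rng.integers(0, 10, size=(num_examples,)).astype(np.int64)
    if as_supervised:
        # Like as_supervised=True: each element is an (image, label) tuple
        return tf.data.Dataset.from_tensor_slices((images, labels))
    # Like the default as_supervised=False: each element is a dict of features
    return tf.data.Dataset.from_tensor_slices({"image": images, "label": labels})

# Like tfds.load, map each split name to its own dataset
supervised = {"train": make_fake_split(6, True), "test": make_fake_split(2, True)}
plain = {"train": make_fake_split(6, False), "test": make_fake_split(2, False)}

print(list(supervised))                  # the split names
print(supervised["train"].element_spec)  # a tuple of two TensorSpecs
print(plain["train"].element_spec)       # a dict with 'image' and 'label' keys

# Getting one example out differs between the two forms
image, label = next(iter(supervised["train"]))
example = next(iter(plain["train"]))
print(image.shape, example["image"].shape)
```

In other words, you can only plot after picking a split out of the dict, and then either unpacking a tuple (supervised form) or indexing with "image" (dict form).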
From the mnist_info object you can get all the information you need to manipulate your dataset: the splits and their sizes, the features and their types (e.g. "image" is a 28x28x1 image with dtype tf.uint8), and so on.
I am getting an error while loading MNIST with this code:
mnist_dataset, mnist_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
Error: __init__() missing 2 required positional arguments: 'op' and 'message'
Source: Udemy course
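Try this:
x_train, y_train = next(iter(mnist_train))
then plot x_train. This is a single (28, 28, 1) image, so squeeze away the channel axis first, e.g. plt.imshow(x_train.numpy().squeeze(), cmap='gray').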
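The snippet above can be extended into a small sketch that plots a grid of examples. It uses random stand-in data instead of the real mnist_train (assumption: replace fake_train with the mnist_train obtained from tfds.load(..., as_supervised=True)). Two details matter: .numpy() converts the eager tensors to NumPy arrays, and the trailing channel axis is squeezed so that plt.imshow receives a 2-D array. The figure is saved to a hypothetical file mnist_examples.png using the non-interactive Agg backend; in a notebook you would call plt.show() instead.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this also runs without a display
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for mnist_train loaded with as_supervised=True
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(9, 28, 28, 1)).astype(np.uint8)
fake_labels = rng.integers(0, 10, size=(9,)).astype(np.int64)
fake_train = tf.data.Dataset.from_tensor_slices((fake_images, fake_labels))

# One (image, label) pair, exactly like next(iter(mnist_train))
x_train, y_train = next(iter(fake_train))

fig, axes = plt.subplots(3, 3, figsize=(6, 6))
# take(9) yields the first nine (image, label) tuples
for ax, (img, lbl) in zip(axes.flat, fake_train.take(9)):
    # Drop the channel axis: (28, 28, 1) -> (28, 28)
    ax.imshow(np.squeeze(img.numpy(), axis=-1), cmap="gray")
    ax.set_title(f"label: {int(lbl)}")
    ax.axis("off")
fig.tight_layout()
fig.savefig("mnist_examples.png")
```

If you loaded the data without as_supervised=True, each element is a dict instead, so the loop would read the image and label via example["image"] and example["label"].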