
Using MNIST data with Keras

I am currently playing around with MNIST data as part of a course on using NumPy and TensorFlow. I was running the code they provided in the course, and I noticed a few warnings from TensorFlow when running this snippet of code:

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("../data/mnist_data/", one_hot=True)

I looked into the documentation and read that this is deprecated and that one should use MNIST from Keras instead. So I changed the above code to this:

from keras.datasets import mnist
from keras.models import Sequential, load_model
from keras.layers.core import Dense, Dropout, Activation
from keras.utils import np_utils

(X_train, y_train), (X_test, y_test) = mnist.load_data()

My issue now is that the course material uses this function:

training_digits, training_labels = mnist.train.next_batch(5000)

The function next_batch() isn't available in Keras, and the original MNIST dataset is pretty large. Is there a clever way to do this with Keras?

Many thanks in advance!

You can set batch_size and use a one-shot iterator, as described here in the Keras MNIST documentation.
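A minimal sketch of that idea, assuming TensorFlow 2.x with eager execution, where you can call `iter()` on a `tf.data.Dataset` instead of building a one-shot iterator. Random arrays with MNIST's shapes stand in for the output of `mnist.load_data()` so the snippet runs offline, and the batch size of 5000 mirrors `next_batch(5000)`.

```python
import numpy as np
import tensorflow as tf

# Stand-in for (x_train, y_train) from tf.keras.datasets.mnist.load_data():
# uint8 images of shape (N, 28, 28) and integer labels 0-9.
rng = np.random.default_rng(0)
x_train = rng.integers(0, 256, size=(12000, 28, 28), dtype=np.uint8)
y_train = rng.integers(0, 10, size=(12000,), dtype=np.uint8)

# Shuffle (reshuffled every epoch), repeat forever, then batch.
# Repeating before batching means every batch is full-sized,
# even when a batch straddles an epoch boundary.
dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=len(x_train))
    .repeat()
    .batch(5000)
)
batches = iter(dataset)

# Rough equivalent of mnist.train.next_batch(5000)
training_digits, training_labels = next(batches)
print(training_digits.shape, training_labels.shape)
```

Each further `next(batches)` call returns the next 5000 samples, much like repeated `next_batch()` calls in the course code.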

Use Sequential() from Keras. Sequential() has a method called fit(), where you can set the batch size with the batch_size parameter. See the documentation: keras Sequential
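For illustration, a minimal sketch of passing `batch_size` to `fit()`, assuming `tf.keras` from TensorFlow 2.x. Keras slices the arrays into mini-batches itself, so no `next_batch()` call is needed. Small random arrays with MNIST's image shape replace the real data so the example runs offline; the tiny one-layer model is only there to have something to fit.

```python
import numpy as np
import tensorflow as tf

# Stand-in for normalised MNIST data: float images in [0, 1], integer labels.
rng = np.random.default_rng(0)
x_train = rng.random((1000, 28, 28)).astype("float32")
y_train = rng.integers(0, 10, size=(1000,))

model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# batch_size sets how many samples go into each gradient update:
# 1000 samples / batch_size 100 -> 10 steps per epoch.
history = model.fit(x_train, y_train, batch_size=100, epochs=2, verbose=0)
print(history.history["loss"])
```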

The issue is that your tutorial uses a different API from the keras.datasets API used in most current tutorials. By using the keras.datasets API you are trying to 'cross the streams'.

You (broadly) have three options:

Option 1

Just stick with your existing tutorial and ignore the deprecation warnings. This is super straightforward, but you may miss out on the benefits of the Keras API (the new default) unless you intend to learn it later.

Option 2

Switch entirely to the Keras API and find a new tutorial. Here is an MNIST example in just a few lines of code:

import tensorflow as tf

mnist = tf.keras.datasets.mnist

# Load the data and scale pixel values from 0-255 to the range 0-1
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

# y_train holds integer labels 0-9, hence the sparse loss
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)

model.evaluate(x_test, y_test)

If it's available to you, this is the option I'd recommend: Keras is the new default. Perhaps this isn't an option, or you want to stick with your original course, but I'd certainly recommend becoming familiar with Keras soon.

Option 3

Find a way to successfully 'cross the streams'.

This is trickier but can certainly be done. The MNIST data from keras.datasets is just big NumPy arrays after all. You could look into the tf.data Dataset API (in particular Dataset.from_tensors() and Dataset.from_tensor_slices()). These options would need a little bit of wrangling though, because (as you discovered) the data returned by the new method is of a different type from that returned by the old one.
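A NumPy-only sketch of crossing the streams without the Dataset API: it reshapes the arrays from `mnist.load_data()` into the layout the old tutorial expects (784-pixel rows scaled to [0, 1] and one-hot labels, as with `one_hot=True`) and wraps them in a hypothetical `BatchFeeder` class whose `next_batch()` mimics `mnist.train.next_batch()`. Random arrays stand in for the real data so it runs offline. This sketch simply drops the leftover samples at the end of an epoch and reshuffles, which may differ from how the original implementation handles epoch boundaries.

```python
import numpy as np


def to_tutorial_format(images, labels, num_classes=10):
    """Convert keras.datasets.mnist arrays to the old input_data layout:
    flattened float32 images in [0, 1] and one-hot float32 labels."""
    x = images.reshape(len(images), -1).astype(np.float32) / 255.0
    y = np.eye(num_classes, dtype=np.float32)[labels]
    return x, y


class BatchFeeder:
    """Hypothetical helper mimicking mnist.train.next_batch(n): returns
    consecutive slices of a shuffled order, reshuffling each epoch."""

    def __init__(self, images, labels, seed=0):
        self._images = images
        self._labels = labels
        self._rng = np.random.default_rng(seed)
        self._order = self._rng.permutation(len(images))
        self._pos = 0

    def next_batch(self, batch_size):
        if batch_size > len(self._images):
            raise ValueError("batch_size is larger than the dataset")
        if self._pos + batch_size > len(self._images):
            # Not enough samples left in this epoch: reshuffle and restart.
            self._order = self._rng.permutation(len(self._images))
            self._pos = 0
        idx = self._order[self._pos:self._pos + batch_size]
        self._pos += batch_size
        return self._images[idx], self._labels[idx]


# In practice: (X_train, y_train), _ = keras.datasets.mnist.load_data()
rng = np.random.default_rng(1)
X_train = rng.integers(0, 256, size=(12000, 28, 28), dtype=np.uint8)
y_train = rng.integers(0, 10, size=(12000,), dtype=np.uint8)

x, y = to_tutorial_format(X_train, y_train)
train = BatchFeeder(x, y)
training_digits, training_labels = train.next_batch(5000)
print(training_digits.shape, training_labels.shape)  # (5000, 784) (5000, 10)
```

Because the arrays are already in the old format, the rest of the course code that consumes `training_digits` and `training_labels` should not need to change.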

UPDATE:

The link in nag's answer provides a comprehensive example of doing this, which I was unaware of previously!

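Statement: technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.

Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM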