繁体   English   中英

TensorFlow从mnist数据集中选择标签

[英]TensorFlow select Labels from mnist dataset

我正在使用tensorflow.examples.tutorials.mnist来训练一个有5个隐藏层的nn。

这是我训练神经网络的方式:

with tf.Session() as sess:
init.run()
for epoch in range(n_epochs):
    for iteration in range(len(mnist.test.labels)//batch_size):
        X_batch, y_batch = mnist.train.next_batch(batch_size)
        sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
    acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
    acc_test = accuracy.eval(feed_dict={X: mnist.test.images, y: mnist.test.labels})
    print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)

我想训练神经网络只识别从0到4的数字。我将logits图层更改为5个输出。

如何过滤TensorFlow提供的mnist数据集,以便只获取0到4之间的数字?

有很多方法可以做到这一点。 其中之一就是当你提取X_batch, y_batch = mnist.train.next_batch(batch_size) 在这一步,你的y_batch将有关于数字值的信息(数字值或数字的一个热点)。

您遍历批处理中的示例并检查数字是否是您关心的数字。 如果是,则将其添加到cleaned_up_batch 不是很有效但它会起作用。


回答评论:

它效率不高,因为您可能需要多次过滤相同的数据。 我不认为这会是一个问题,因为MNIST非常小。 通常的做法是只过滤一次,创建一个新的数据集并编写自己的函数来从中获取下一批(实际上非​​常简单,因为你只需从数据集中随机选择k个元素)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM