
TensorFlow MNIST neural network model has low accuracy

I am getting only about 10 percent accuracy (0.098) with this basic MNIST model.

Can anyone help me understand what I am doing wrong here?

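import numpy as np
import tensorflow as tf  # uses the TensorFlow 1.x graph API (tf.placeholder, tf.train, sessions)

mnist = tf.keras.datasets.mnist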
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = np.reshape(x_train, (60000, 784))
x_test  = np.reshape(x_test, (10000, 784))
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))
w = tf.Variable(tf.zeros((784, 10)))
b = tf.Variable(tf.zeros((10)))
y_hat = tf.nn.softmax(tf.matmul(x, w) + b)

cross_entropy = tf.reduce_mean(-tf.reduce_sum((y * tf.log(y_hat)), axis=1))
training_gd = tf.train.GradientDescentOptimizer(learning_rate=0.05).minimize(cross_entropy)

sess = tf.InteractiveSession()
tf.initializers.global_variables().run()

for _ in range(10000):
    indices = np.random.randint(0, len(x_train), 100)
    batch_xs, batch_ys =  x_train[indices], y_train[indices]
    sess.run(training_gd, feed_dict={x: batch_xs, y: batch_ys})

correct_prediction = tf.equal(tf.argmax(y_hat, axis=1), tf.argmax(y, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: x_test, y: y_test}))
# Output: 0.098

Computing cross entropy with:

cross_entropy = tf.reduce_mean(-tf.reduce_sum((y * tf.log(y_hat)), axis=1))

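is numerically unstable. The raw pixel values range from 0 to 255, so the logits tf.matmul(x, w) + b quickly grow very large during training. tf.nn.softmax then rounds some class probabilities to exactly 0, tf.log(0) returns -inf, and y * tf.log(y_hat) evaluates 0 * -inf = NaN for every class whose label is 0. Once the loss is NaN, the gradients and weights become NaN as well and the model stops learning. The reported 0.098 is also exactly the share of the digit 0 in the MNIST test set (980 of 10,000 images), which is consistent with the broken model predicting the same class for every test image.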
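The failure can be reproduced without TensorFlow. The sketch below is a NumPy approximation, not TensorFlow itself: naive_cross_entropy is a hypothetical helper that applies the same formula as the question, and the example logits are made-up values of the magnitude that unscaled pixels can produce. Even when the largest logit belongs to the correct class, the loss comes out as NaN.

```python
import numpy as np

def naive_cross_entropy(logits, labels):
    # Same formula as the question: mean(-sum(y * log(softmax(logits)))).
    z = logits - logits.max(axis=1, keepdims=True)  # max-shift avoids exp overflow
    y_hat = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    with np.errstate(divide="ignore", invalid="ignore"):
        # Probabilities that underflow to exactly 0 give log(0) = -inf,
        # and 0 * -inf = nan for every class whose label is 0.
        return np.mean(-np.sum(labels * np.log(y_hat), axis=1))

# Hypothetical large logits, as produced by 0-255 pixel inputs.
logits = np.array([[2000.0, 0.0, -500.0]])
labels = np.array([[1.0, 0.0, 0.0]])  # the top logit is the correct class
print(naive_cross_entropy(logits, labels))  # nan
```

With small logits such as [[1.0, 2.0, 0.5]] the same function returns a finite value, which is why scaling the inputs down is enough to hide the problem in practice.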

One fix is to scale the pixel values from the range 0 to 255 down to the range 0 to 1, which keeps the logits much smaller:

x_train = np.reshape(x_train, (60000, 784)) / 255.0
x_test  = np.reshape(x_test, (10000, 784)) / 255.0
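The same preprocessing can be wrapped in a small helper that does not hard-code the number of images. This is a sketch; the name preprocess is hypothetical, and the explicit cast to float32 is an assumption chosen to match the tf.float32 placeholder in the question.

```python
import numpy as np

def preprocess(images):
    # Flatten each 28x28 image into a 784-element row (works for any number of images).
    flat = images.reshape(len(images), -1).astype(np.float32)
    # Scale uint8 pixel values from [0, 255] to [0, 1].
    return flat / 255.0

# Usage with random stand-in data shaped like MNIST images.
fake_images = np.random.randint(0, 256, size=(5, 28, 28), dtype=np.uint8)
batch = preprocess(fake_images)
print(batch.shape, batch.dtype)
```

In the question's script this would replace the two reshape lines with x_train = preprocess(x_train) and x_test = preprocess(x_test).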

Alternatively (or additionally), compute the cross entropy from the logits with TensorFlow's built-in op, which is numerically stable:

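logits = tf.matmul(x, w) + b
y_hat = tf.nn.softmax(logits)  # still needed for the accuracy computation

# Fused, numerically stable op. It returns one loss value per example, so average over the batch.
# (TensorFlow 1.5 and later also provide tf.nn.softmax_cross_entropy_with_logits_v2.)
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))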
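The built-in op avoids the problem because it works with log-probabilities instead of taking the log of a softmax output that may already be 0. The NumPy sketch below illustrates that idea; it is not TensorFlow's actual implementation, and the function names are hypothetical. It computes log-softmax with the log-sum-exp trick and uses the standard gradient (softmax(logits) - labels) / batch_size, both of which stay finite for extreme logits.

```python
import numpy as np

def log_softmax(logits):
    # log-sum-exp trick: log p_i = z_i - log(sum_j exp(z_j)), with z = logits - max.
    # The largest shifted logit is 0, so the sum is >= 1 and its log is never -inf.
    z = logits - logits.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def cross_entropy_with_logits(logits, labels):
    """Mean softmax cross entropy over the batch and its gradient w.r.t. the logits."""
    log_p = log_softmax(logits)
    loss = np.mean(-np.sum(labels * log_p, axis=1))
    # Standard result: d(loss)/d(logits) = (softmax(logits) - labels) / batch_size.
    grad = (np.exp(log_p) - labels) / logits.shape[0]
    return loss, grad

# Hypothetical extreme logits of the kind unscaled pixels can produce.
logits = np.array([[2000.0, 0.0, -500.0]])
labels = np.array([[0.0, 1.0, 0.0]])
loss, grad = cross_entropy_with_logits(logits, labels)
print(loss)  # 2000.0
print(grad)  # finite: no NaN even though softmax underflows for two classes
```

The loss is large because the prediction is confidently wrong, but it is a finite number with a finite gradient, so gradient descent can still correct the weights.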


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM