Low accuracy in neural network with TensorFlow

I was following a Google codelab on neural networks and decided to use the CIFAR-10 dataset instead of the MNIST dataset to make a simple image classifier, but for some reason I get very low accuracy and high cross-entropy.

After training, the accuracy is around 0.1 (never more than 0.2) and the cross-entropy doesn't go below 230.

My code:

import tensorflow as tf
import numpy as np
import matplotlib as mpt
# Just disables the warning, doesn't enable AVX/FMA
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        # CIFAR-10 batches are pickled with byte-string keys, e.g. b'data'
        batch = pickle.load(fo, encoding='bytes')
    return batch

def returnMiniBatch(dictionary,start,number):
    # Raw pixel values (0-255) and class indices (0-9)
    matrix=np.zeros([number,3072],dtype=int)
    labels=np.zeros([number],dtype=int)
    for i in range(0,number):
        matrix[i]=dictionary[b'data'][i+start]
        labels[i]=dictionary[b'labels'][i+start]
    return matrix,labels

def formatLabels(labels,number):
    lab=np.zeros([number,10])
    for i in range(0,number):
        lab[i][labels[i]]=1
    return lab

data='D:/cifar-10-python/cifar-10-batches-py/data_batch_1'
dictionary=unpickle(data)
tf.set_random_seed(0)

L = 200
M = 100
N = 60
O = 30


X=tf.placeholder(tf.float32,[None,3072])
Y_=tf.placeholder(tf.float32,[None,10])



W1 = tf.Variable(tf.truncated_normal([3072,L],stddev=0.1))
B1 = tf.Variable(tf.ones([L])/10)
W2 = tf.Variable(tf.truncated_normal([L, M], stddev=0.1))
B2 = tf.Variable(tf.ones([M])/10)
W3 = tf.Variable(tf.truncated_normal([M, N], stddev=0.1))
B3 = tf.Variable(tf.ones([N])/10)
W4 = tf.Variable(tf.truncated_normal([N, O], stddev=0.1))
B4 = tf.Variable(tf.ones([O])/10)
W5 = tf.Variable(tf.truncated_normal([O, 10], stddev=0.1))
B5 = tf.Variable(tf.ones([10]))

Y1 = tf.nn.relu(tf.matmul(X, W1) + B1)
Y2 = tf.nn.relu(tf.matmul(Y1, W2) + B2)
Y3 = tf.nn.relu(tf.matmul(Y2, W3) + B3)
Y4 = tf.nn.relu(tf.matmul(Y3, W4) + B4)


Ylogits = tf.matmul(Y4, W5) + B5
Y = tf.nn.softmax(Ylogits)

cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=Ylogits,
                                                        labels=Y_)
# Mean over the batch, scaled by 100
cross_entropy = tf.reduce_mean(cross_entropy)*100

correct_prediction=tf.equal(tf.argmax(Y,1),tf.argmax(Y_,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

train_step = tf.train.AdamOptimizer(0.003).minimize(cross_entropy)

init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)

def training_step(i):
    global dictionary
    val,lab=returnMiniBatch(dictionary,i * 100,100)
    Ylabels=formatLabels(lab,100)
    _,a,c = sess.run([train_step,accuracy, cross_entropy], feed_dict={X: 
                     val, Y_: Ylabels})
    print("Accuracy: ",a)
    print("Cross-Entropy",c)

for i in range (0,100):
    training_step(i%100)

If I'm not mistaken, that looks like a non-convolutional network. You need to look for a convolutional network architecture. So look for some tutorial using conv2d.

Reason: MNIST is single-channel greyscale data. CIFAR-10 has 3 channels (RGB) with 8-bit colour. It's not enough to just increase the size of the input placeholder. You need to tell the network that the three channels (and neighbouring pixels) are related. You do this by using a convolutional network architecture.
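To feed a convolutional layer, each flat 3072-value row first has to be turned back into an image. A minimal sketch, assuming the layout of the Python version of CIFAR-10 (1024 red values, then 1024 green, then 1024 blue, each in row-major order); `rows_to_images` is a hypothetical helper name, not part of the code in the question:

```python
import numpy as np

def rows_to_images(rows):
    """Turn flat CIFAR-10 rows of shape [N, 3072] into images of shape [N, 32, 32, 3]."""
    rows = np.asarray(rows)
    # Each row is stored channel by channel: 1024 R, 1024 G, 1024 B values.
    planes = rows.reshape(-1, 3, 32, 32)
    # Move the channel axis last (NHWC), the default layout for tf.nn.conv2d.
    return planes.transpose(0, 2, 3, 1)

# Synthetic stand-in for dictionary[b'data'][:2]; real data would be uint8 pixels.
fake_rows = np.arange(2 * 3072).reshape(2, 3072)
images = rows_to_images(fake_rows)
print(images.shape)  # (2, 32, 32, 3)
```

The input placeholder would then have shape [None, 32, 32, 3] instead of [None, 3072].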

An accuracy of 0.1 is no better than random chance for 10 classes. The network isn't learning anything that generalises.
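Both reported numbers match a network that guesses uniformly. A quick check, assuming 10 equally likely classes and the factor of 100 that the question's code applies to the mean cross-entropy:

```python
import math

num_classes = 10
# Uniform guessing assigns probability 1/10 to every class.
chance_accuracy = 1.0 / num_classes
# Cross-entropy of a uniform prediction is -ln(1/10) = ln(10) per example;
# the question's code multiplies the mean by 100.
chance_loss = 100 * math.log(num_classes)

print(chance_accuracy)        # 0.1
print(round(chance_loss, 1))  # 230.3
```

A loss that stalls around 230 is therefore what you would expect if the outputs stay close to uniform.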

The solution was to normalize the input data. I added a new function that scales the pixel values to the range [0, 1]:

def formatData(values):
    ret = values.reshape(100,3072).astype("float32")
    ret/=255
    return ret

and formatted the data before adding it to the feed dictionary in training_step:

    X_Data=formatData(val)
    _,a,c = sess.run([train_step,accuracy, cross_entropy],
                     feed_dict={X: X_Data, Y_: Ylabels})
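The same scaling can be written without the hard-coded batch size of 100. A sketch, assuming the raw values are 8-bit pixels in [0, 255]; `normalize_batch` is a hypothetical name, not from the original post:

```python
import numpy as np

def normalize_batch(values):
    """Scale raw 8-bit pixel rows to float32 values in [0, 1]."""
    # -1 lets NumPy infer the batch size instead of fixing it at 100.
    flat = np.asarray(values).reshape(-1, 3072).astype(np.float32)
    return flat / 255.0

# Synthetic stand-in for a mini-batch of 4 CIFAR-10 rows.
raw = np.random.randint(0, 256, size=(4, 3072))
scaled = normalize_batch(raw)
print(scaled.shape, scaled.dtype)
```

With raw inputs up to 255 across 3072 features, the first-layer activations are likely very large, which is a plausible reason the unnormalized network failed to train.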

After this change the network started learning properly (a convolutional network is still much better here, as Pam pointed out).
