简体   繁体   中英

training accuracy drops in tensorflow

I was trying to create a model for character recognition. The model worked fine with a 28*28 dataset and the characters 0-9, but its training accuracy drops when I change to 64*64 images and characters ranging over 0-9, a-z, A-Z. While iterating, the accuracy rises to about 0.3 and then stays there. I tried training with a different dataset as well, but the same thing happens. Changing the learning rate to 0.001 does not help either. Can anyone tell me what the issue is?

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import random as ran
import os
import tensorflow as tf

def TRAIN_SIZE(num):
    """Load the training set from disk and return its first ``num`` examples.

    Images are read from ``data/train/images64.npy`` and flattened to one
    4096-value row per image; labels are read from ``data/train/labels.npy``.
    Shape information is printed along the way.

    Args:
        num: Number of leading examples to keep. Ordinary slice semantics
            apply, so a value larger than the dataset returns everything.

    Returns:
        A tuple ``(x_train, y_train)`` holding the first ``num`` image rows
        and the first ``num`` label rows.
    """
    # Let NumPy infer the number of images (-1) instead of hard-coding it,
    # so the loader keeps working when the dataset grows or is swapped out.
    images = np.load("data/train/images64.npy").reshape([-1, 4096])
    labels = np.load("data/train/labels.npy")
    print('Total Training Images in Dataset = ' + str(images.shape))
    print('--------------------------------------------------')
    x_train = images[:num, :]
    print('x_train Examples Loaded = ' + str(x_train.shape))
    # Labels must be 2-D (one row per example) for this slice to work.
    y_train = labels[:num, :]
    print('y_train Examples Loaded = ' + str(y_train.shape))
    print('')
    return x_train, y_train

def TEST_SIZE(num):
    """Load the test set from disk and return its first ``num`` examples.

    Images are read from ``data/test/images64.npy`` and flattened to one
    4096-value row per image; labels are read from ``data/test/labels.npy``.
    Shape information is printed along the way.

    Args:
        num: Number of leading examples to keep. Ordinary slice semantics
            apply, so a value larger than the dataset returns everything.

    Returns:
        A tuple ``(x_test, y_test)`` holding the first ``num`` image rows
        and the first ``num`` label rows.
    """
    # Let NumPy infer the number of images (-1) instead of hard-coding it,
    # so the loader keeps working when the dataset grows or is swapped out.
    images = np.load("data/test/images64.npy").reshape([-1, 4096])
    labels = np.load("data/test/labels.npy")
    print('Total testing Images in Dataset = ' + str(images.shape))
    print('--------------------------------------------------')
    x_test = images[:num, :]
    print('x_test Examples Loaded = ' + str(x_test.shape))
    # Labels must be 2-D (one row per example) for this slice to work.
    y_test = labels[:num, :]
    print('y_test Examples Loaded = ' + str(y_test.shape))
    print('')
    return x_test, y_test

def display_digit(num):
    """Plot training example ``num`` as a 64x64 image.

    Reads the module-level ``x_train`` / ``y_train`` arrays. The plot title
    shows the example index and the position of the largest entry in the
    example's label row (the class index, not the character itself).
    """
    class_index = np.argmax(y_train[num], axis=0)
    pixels = x_train[num].reshape([64, 64])
    title = 'Example: %d  Label: %d' % (num, class_index)
    plt.title(title)
    # 'gray_r' is the reversed grayscale colour map.
    colour_map = plt.get_cmap('gray_r')
    plt.imshow(pixels, cmap=colour_map)
    plt.show()

def display_mult_flat(start, stop):
    """Plot several flattened training examples stacked as image rows.

    Reads the module-level ``x_train`` array. Example ``start`` is always
    included, followed by examples ``start + 1`` up to (but excluding)
    ``stop``; each example becomes one 4096-pixel-wide row of the image.
    """
    # Row `start` is unconditional; the remaining indices may be empty.
    indices = [start]
    indices.extend(range(start + 1, stop))
    rows = [x_train[idx].reshape([1, 4096]) for idx in indices]
    stacked = np.concatenate(rows)
    plt.imshow(stacked, cmap=plt.get_cmap('gray_r'))
    plt.show()

def get_char(a):
    """Translate a class index into the symbol it represents.

    Indices below 10 are handed back unchanged (as the number itself, not
    a string). Indices 10-35 map to 'A'-'Z'. Every index from 36 upward is
    offset from 'a', so 36-61 map to 'a'-'z'.
    """
    if a < 10:
        return a
    if a < 36:
        # 10 -> 'A', 35 -> 'Z'
        return chr(ord('A') + a - 10)
    # 36 -> 'a', 61 -> 'z'
    return chr(ord('a') + a - 36)

# Load the training split and the held-out test split. The test split must
# come from TEST_SIZE; loading it through TRAIN_SIZE would evaluate the model
# on the very data it was trained on.
x_train, y_train = TRAIN_SIZE(2850)
x_test, y_test = TEST_SIZE(400)

LEARNING_RATE = 0.2
TRAIN_STEPS = 1000

# Single-layer softmax regression: 4096 flattened pixels -> 62 class scores.
# NOTE(review): pixels are fed to the model unscaled; if the .npy files hold
# 0-255 values, scaling them to [0, 1] would likely help -- confirm against
# the dataset.
x = tf.placeholder(tf.float32, shape=[None, 4096])
y_ = tf.placeholder(tf.float32, shape=[None, 62])
W = tf.Variable(tf.zeros([4096, 62]))
b = tf.Variable(tf.zeros([62]))
logits = tf.matmul(x, W) + b
y = tf.nn.softmax(logits)

# Compute the loss from the raw logits. Taking tf.log of the softmax output
# yields -inf/NaN as soon as a probability underflows to 0, which stalls
# training; the fused op is the numerically stable equivalent.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
training = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cross_entropy)
# A prediction is correct when the highest-scoring class matches the
# position of the largest entry in the label row.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    # Initialise only after the whole graph (including the optimizer) has
    # been built, so every variable it created is covered.
    sess.run(tf.global_variables_initializer())

    for i in range(TRAIN_STEPS + 1):
        # Full-batch gradient step on the entire training set.
        sess.run(training, feed_dict={x: x_train, y_: y_train})
        if i % 100 == 0:
            # Report accuracy on the test split and loss on the training split.
            test_accuracy = sess.run(accuracy, feed_dict={x: x_test, y_: y_test})
            train_loss = sess.run(cross_entropy, {x: x_train, y_: y_train})
            print('Training Step:' + str(i) + '  Accuracy =  ' + str(test_accuracy) + '  Loss = ' + str(train_loss))

    savedPath = tf.train.Saver().save(sess, "/tmp/model.ckpt")
    print("Model saved at: " ,savedPath)

You are trying to classify 62 different digits and characters, but you use a single fully connected layer to do that. Your model simply does not have enough parameters for the task. In other words, you are underfitting the data. So either expand your network by adding parameters (layers), and/or use CNNs, which generally perform well on image classification tasks.

Try a different model: a CNN architecture such as Inception v1, v2 or v3, AlexNet, etc.

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM