簡體   English   中英

我的字符分類 model 在每種情況下都預測錯誤,我不確定是因為我的代碼還是我的數據

[英]My character classification model predicts wrong in every case, and I'm not sure if it's because of my code or my data

我正在構建一個 python 程序來從圖像中讀取電話號碼,我想通過一個一個地讀取每個數字然后在最后打印出來來做到這一點。 現在我正試圖讓它只與一個角色一起工作。 這是我用於數字 3 的訓練數據,rest 非常相似(請注意,我還包括 - 和 /,因為某些電話號碼包含這些字符)。

3 號數據集

當我嘗試在 model 中輸入一些數字進行預測時,它基本上是 0% 的准確度,我不確定我的代碼是否做錯了,還是我的訓練數據很糟糕? 一些輸入的例子:

輸入數據

最后是我的代碼:

import os
import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

VALUES = {
    '0': '0',
    '1': '1',
    '2': '2',
    '3': '3',
    '4': '4',
    '5': '5',
    '6': '6',
    '7': '7',
    '8': '8',
    '9': '9',
    '10': '-',
    '11': '/',
}


def normalize(image, label):
    image = tf.cast(image / 255., tf.float32)
    return image, label


train_new_model = True

if train_new_model:
    print("Loading data set...")
    dataset = tf.keras.utils.image_dataset_from_directory(
        'dataset', image_size=(28, 28), color_mode='grayscale', label_mode='int', labels="inferred"
    )
    validation = tf.keras.utils.image_dataset_from_directory(
        'validation', image_size=(28, 28), color_mode='grayscale', label_mode='int', labels="inferred"
    )

    print("Normalizing data...")
    dataset = dataset.map(normalize)

    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(units=128, activation=tf.nn.relu))
    model.add(tf.keras.layers.Dense(units=128, activation=tf.nn.relu))
    model.add(tf.keras.layers.Dense(units=len(VALUES), activation=tf.nn.softmax))

    print("Compiling model...")
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    print("Training model...")
    history = model.fit(dataset, epochs=100, validation_data=validation)

    print(history.history.keys())
    # summarize history for accuracy
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.show()
    # summarize history for loss
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.show()

    input()

    val_loss, val_acc = model.evaluate(dataset)
    print("Loss: ", end="")
    print(val_loss)
    print("Accuracy: ", end="")
    print(val_acc)

    print("Saving model...")
    model.save('digits.model')
else:
    # Load the model
    model = tf.keras.models.load_model('digits.model')

image_number = 1
while os.path.isfile('digits/digit{}.png'.format(image_number)):
    try:
        # Convert image to a flat black and white image, and invert them
        originalImage = cv2.imread('digits/digit{}.png'.format(image_number))
        grayImage = cv2.cvtColor(originalImage, cv2.COLOR_BGR2GRAY)
        (thresh, blackAndWhiteImage) = cv2.threshold(grayImage, 127, 255, cv2.THRESH_BINARY)
        img = 255 - blackAndWhiteImage
        img = np.invert(np.array([img]))

        # Predict and read from solution VALUES
        prediction = model.predict(img)
        solution = np.argmax(prediction)
        print("Predicted character:", end="")
        print(VALUES[str(solution)])

        # Show the image
        plt.imshow(img[0], cmap=plt.cm.binary)
        plt.show()

        input()  # Wait for input to continue
    except BaseException as error:
        print('An exception occurred: {}'.format(error))
        print("Error reading image! Proceeding with next image...")
    finally:
        image_number += 1

這是 model 損失圖:

在此處輸入圖像描述

對我來說突出的一個潛在問題是數據集中的一些圖像是倒置的。 (在您的第一個屏幕截圖中,我在白色背景上看到一些黑色 3,在黑色背景上看到一些白色 3)。 我知道如果 model 可以自己處理這種差異會很酷,但這可能不現實,至少在沒有幫助的情況下並非如此。 嘗試將您的數據集配對,以使數字的外觀保持一致,重新訓練,然后看看它的表現如何。

為了進一步診斷代碼,一些 output 會很好。 隨着網絡訓練,損失會發生什么? 是炸了還是沒收斂?

一旦確定了這個問題,您就可以開始考慮如何處理具有黑白背景的更多樣化的數據集。 也許您可以包含一個通道以檢測輸入圖像上的“類型”並將其標准化(即將黑白圖像轉換為黑白圖像,並將其傳遞給分類器)。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM