簡體   English   中英

當我測試為多個類別數據集訓練的圖像分類模型時,為什么會得到相同的類別預測?

[英]Why am I getting same class prediction when I test image classification model trained for multiple class datasets?

我正在嘗試使用具有 4 個不同類的 tf.data 在花卉圖像數據集上構建圖像分類模型。 當我測試經過訓練的模型時,即使對於不同的類圖像,我也會得到相同的類預測,但是訓練進行得很順利,具有良好的訓練准確度和驗證准確度,並且它在測試數據集上也提供了很好的准確度。

我的訓練和測試管道的實現如下

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file(origin=dataset_url,
                                   fname='flower_photos',
                                   untar=True)
data_dir = pathlib.Path(data_dir)
data_dir = pathlib.Path(r'C:\Users\Hilary\.keras\datasets\flower_photos')
slide_labels =os.listdir(data_dir)

CLASS_NAMES = slide_labels 
NUM_CLASSES = len(CLASS_NAMES)
num_examples = len(list(data_dir.glob('*/*.jpg')))

def get_label(file_path):
  # convert the path to a list of path components
    parts = tf.strings.split(file_path, os.path.sep)
  # The second to last is the class-directory
    return tf.where(parts[-2] == CLASS_NAMES)[0][0]

def decode_img(img):
  # convert the compressed string to a 3D uint8 tensor
    img = tf.image.decode_jpeg(img, channels=3)
    return img  

def process_path(file_path):
    label = get_label(file_path)
  # load the raw data from the file as a string
    img = tf.io.read_file(file_path)
    img = decode_img(img)
    features = {'image': img, 'label': label}
    return features

list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'))
ds = list_ds.map(process_path, num_parallel_calls=tf.data.experimental.AUTOTUNE)
print("Total number of images", len(ds))
print("total No of classes: ",NUM_CLASSES)


train_size = int(0.90 * num_examples)
val_size = int(0.09* num_examples)
test_size = int(0.01 * num_examples)

full_dataset = ds.shuffle(reshuffle_each_iteration=False, buffer_size=len(ds))
train_dataset = full_dataset.take(train_size)
test_val_dataset = full_dataset.skip(train_size)
val_dataset = test_val_dataset.take(val_size)
test_dataset = test_val_dataset.skip(val_size)
print("Number of examples on training set is ", len(train_dataset))

當我對單個班級圖像進行推斷時,如下所示

for img_f in list(paths.list_images(r'C:\Users\hillary\.keras\datasets\flower_photos\sunflowers')):
    img = cv2.imread(img_f)
    test_img = [img]
    # test_img = [np.expand_dims(img, axis=0) for img in test_img]
    test_img = tf.concat(test_img, axis=0)
    test_img = tf.image.resize(test_img, [128, 128])
    test_img = tf.cast(image, tf.float32) / 255.0
    # test_img = tf.expand_dims(image, axis = 0)
    logits = model(test_img)
    y_probabilities = tf.nn.softmax(logits).numpy()[0]
    print(y_probabilities)
    index_max_proba = np.argmax(tf.nn.softmax(logits))
    print(class_labels[index_max_proba])

我得到的結果為

[1.2498085e-03 1.7629927e-01 8.2240731e-01 3.1031032e-05 1.2520954e-05]
roses
[1.2498085e-03 1.7629927e-01 8.2240731e-01 3.1031032e-05 1.2520954e-05]
roses
[1.2498085e-03 1.7629927e-01 8.2240731e-01 3.1031032e-05 1.2520954e-05]
roses
[1.2498085e-03 1.7629927e-01 8.2240731e-01 3.1031032e-05 1.2520954e-05]
roses
[1.2498085e-03 1.7629927e-01 8.2240731e-01 3.1031032e-05 1.2520954e-05]
roses
[1.2498085e-03 1.7629927e-01 8.2240731e-01 3.1031032e-05 1.2520954e-05]
roses
[1.2498085e-03 1.7629927e-01 8.2240731e-01 3.1031032e-05 1.2520954e-05]
roses
[1.2498085e-03 1.7629927e-01 8.2240731e-01 3.1031032e-05 1.2520954e-05]
roses
[1.2498085e-03 1.7629927e-01 8.2240731e-01 3.1031032e-05 1.2520954e-05]
roses
[1.2498085e-03 1.7629927e-01 8.2240731e-01 3.1031032e-05 1.2520954e-05]
roses
[1.2498085e-03 1.7629927e-01 8.2240731e-01 3.1031032e-05 1.2520954e-05]
roses
[1.2498085e-03 1.7629927e-01 8.2240731e-01 3.1031032e-05 1.2520954e-05]
roses
[1.2498085e-03 1.7629927e-01 8.2240731e-01 3.1031032e-05 1.2520954e-05]
roses
[1.2498085e-03 1.7629927e-01 8.2240731e-01 3.1031032e-05 1.2520954e-05]
roses
[1.2498085e-03 1.7629927e-01 8.2240731e-01 3.1031032e-05 1.2520954e-05]
roses
[1.2498085e-03 1.7629927e-01 8.2240731e-01 3.1031032e-05 1.2520954e-05]
roses
[1.2498085e-03 1.7629927e-01 8.2240731e-01 3.1031032e-05 1.2520954e-05]
roses

它預測向日葵類圖像的玫瑰與其他類圖像相同

我針對不同的數據集和模型測試了這條管道,我得到了相同的結果,它們是針對不同類圖像的單類預測..

任何糾正我的錯誤的幫助或建議將不勝感激

您正在使用 opencv 加載文件,該文件以 BGR 格式加載圖像,而在原始管道中使用 tf.io 加載它。 嘗試使用以下代碼將其轉換為 RGB

img = cv2.imread(img_f)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
test_img = [img]

有趣的是,輸出都是相同的,而不僅僅是相似。

為什么在測試中將圖像值除以 255 而不是在訓練中?

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM