[英]Prediction labels are off - KERAS / Tensorflow
我正在使用 KERAS/TF 制作一個帶有遷移學習的圖像分類器,包括預訓練的模型權重。 圖像數據集拆分為 80/10/10。 類別是范圍從 1 - 80 的字符串標簽。
對於圖像的預處理,我使用 ImageDataGenerator 並且評估顯示了大約 58% 的評估集的良好准確性。 但是,嘗試預測評估集上的值的准確度約為 0.01。 我也在測試集上嘗試過,結果仍然沒有加起來。
有人知道預測出了什么問題嗎?
先感謝您!
此致,
# Predict values using the test generator
predict = model.predict_generator(validation_generator,
steps=(validation_generator.n // 32)+1)
# Choose the highest scoring prediction
validate_df['prediction'] = np.argmax(predict, axis=-1)
# assigning label names to the corresponding indexes
labels = {0 : '1', 1 : '2', 2 : '3', 3 : '4', 4 : '5', 5 : '6', 6 : '7', 7 : '8', 8 : '9', 9 : '10',
10 : '11', 11 : '12', 12 : '13', 13 : '14', 14 : '15', 15 : '16', 16 : '17', 17 : '18', 18 : '19', 19 : '20',
20 : '21', 21 : '22', 22 : '23', 23 : '24', 24 : '25', 25 : '26', 26 : '27', 27 : '28', 28 : '29', 29 : '30',
30 : '31', 31 : '32', 32 : '33', 33 : '34', 34 : '35', 35 : '36', 36 : '37', 37 : '38', 38 : '39', 39 : '40',
40 : '41', 41 : '42', 42 : '43', 43 : '44', 44 : '45', 45 : '46', 46 : '47', 47 : '48', 48 : '49', 49 : '50',
50 : '51', 51 : '52', 52 : '53', 53 : '54', 54 : '55', 55 : '56', 56 : '57', 57 : '58', 58 : '59', 59 : '60',
60 : '61', 61 : '62', 62 : '63', 63 : '64', 64 : '65', 65 : '66', 66 : '67', 67 : '68', 68 : '69', 69 : '70',
70 : '71', 71 : '72', 72 : '73', 73 : '74', 74 : '75', 75 : '76', 76 : '77', 77 : '78', 78 : '79', 79 : '80'}
# Replace with original lables for comparison
validate_df['prediction'] = validate_df['prediction'].replace(labels)
# Print accuracy score and show output
print(accuracy_score(validate_df.label, validate_df.prediction))
test_df.head()
0.013067624959163672
img_name label prediction
0 train_27493.jpg 17 10
1 train_19980.jpg 60 58
2 train_4348.jpg 71 58
3 train_4141.jpg 11 4
4 train_21555.jpg 36 22
對於社區:
我的代碼缺少將值轉換回正確標簽的部分。
選擇得分最高的預測后,該值必須使用 class_indices 轉換回 train_generator
# Convert the predict category back into train_generator.class_indices
label_map = dict((v,k) for k,v in train_generator.class_indices.items())
validate_df['prediction'] = validate_df['prediction'].replace(label_map)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.