簡體   English   中英

CNN模型僅預測1個類別

[英]CNN Model is predicting only 1 class

我已經為二進制分類編碼了CNN模型。 我的數據集有偏差(56000張圖像的類別1和3000張圖像的類別2)。 我正在測試108張圖像(每個班級54張)。 我的模型將每個圖像預測為第1類。您能告訴我我的模型出了什么問題以及如何改善它嗎?

IMG_SIZE = 32
LR = 1e-1


convnet = input_data(shape=[None, IMG_SIZE, IMG_SIZE, 3], name='input')

convnet = conv_2d(convnet, 32, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)

convnet = conv_2d(convnet, 64, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)

convnet = conv_2d(convnet, 128, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)

convnet = conv_2d(convnet, 64, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)

convnet = conv_2d(convnet, 32, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)

convnet = fully_connected(convnet, 1024, activation='relu')
convnet = dropout(convnet, 0.8)

convnet = fully_connected(convnet, 2, activation='softmax')
convnet = regression(convnet, optimizer='adam', learning_rate=LR, loss='categorical_crossentropy', name='targets')

model = tflearn.DNN(convnet, tensorboard_dir='/home/anas/Argentinadata/log')






train = training_data[:50000]
test =training_data[50000:]




X = np.array([i[0] for i in train]).reshape(-1,IMG_SIZE,IMG_SIZE,3)
Y = [i[1] for i in train]

test_x = np.array([i[0] for i in test]).reshape(-1,IMG_SIZE,IMG_SIZE,3)

test_y = [i[1] for i in test]



print(len(test_x))


print(len(X))

print(len(Y))


model.fit({'input': X}, {'targets': Y}, n_epoch=25, validation_set=({'input': test_x}, {'targets': test_y}),
            snapshot_step=500, show_metric=True, run_id=MODEL_NAME)

這僅意味着該模型發現最容易預測一個類別並獲得〜95%的准確性。 嘗試進行權重訓練或復制2類的圖像,直到1類和2類的分割比例約為50/50。

在這種情況下,我將建議兩件事(更直接的事情):

1.使用batchsize = 30-40

  1. 整理數據以在每個批次中包括更好的多樣性

您可以通過tflearn為此目的使用data_utils:

tflearn.data_utils.shuffle (your_array)

您可以在這里查看文檔

您可以做的另一件事是對數據進行重新采樣以包括更多的少數類,但是您沒有足夠大的父樣本。 因此,您可以嘗試手動批量插入少數類的數據。

您還可以在少數派類別上訓練生成模型並逐步生成一些數據點,但可能會導致少數派類別過擬合

更復雜的解決方法是修改成本函數,方法是在少數群體類別的預測錯誤的情況下,分配更高的成本。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM