簡體   English   中英

Keras ImageDataGenerator 預測超過預測集

[英]Keras ImageDataGenerator Predicts More Than The Prediction Set

我正在嘗試制作一個分類美國手語的 CNN model。 我已經創建並訓練了我的 model。 現在我正在嘗試預測課程。 我的預測集有 7250 個未標記的圖像,但是當我進行預測時,model 執行 587250 個預測,而我需要它來執行 7250 個預測。 我提供下面的代碼。 這是什么原因? 難道我做錯了什么?

代碼塊 1:

predict_set = data.flow_from_directory('/content/gdrive/MyDrive/test_data/')

Output 1:

Found 7250 images belonging to 1 classes.

代碼塊 2:

import numpy as np
predictions = model.predict_classes(predict_set)
print(len(predictions),"\n", predictions)

Output 2:

587250 
[27 27 27 ... 27  6 12]

編輯:

CNN Model:

model = Sequential()

# First Layer
model.add(Conv2D(filters = 64, kernel_size = (4, 4), input_shape = (64, 64, 3), activation = 'relu'))
model.add(Conv2D(filters = 64, kernel_size = (4, 4), strides = 2,  activation = 'relu'))
model.add(Dropout(0.5))
model.add(BatchNormalization(axis = 3, momentum = 0.8))

# Second Layer
model.add(Conv2D(filters = 128, kernel_size = (4, 4), activation = 'relu'))
model.add(Conv2D(filters = 128, kernel_size = (4, 4), strides = 2,  activation = 'relu'))
model.add(Dropout(0.5))
model.add(BatchNormalization(axis = 3, momentum = 0.8))

# Third Layer
model.add(Conv2D(filters = 256, kernel_size = (4, 4), activation = 'relu'))
model.add(Conv2D(filters = 256, kernel_size = (4, 4), strides = 2,  activation = 'relu'))
model.add(Dropout(0.5))
model.add(BatchNormalization(axis = 3, momentum = 0.8))

# Flattening
model.add(Flatten())
model.add(Dropout(0.5))

model.add(Dense(units = 512, activation = 'relu')) # Hidden Layer
model.add(Dense(units = 29, activation = 'softmax')) # Output Layer

#Compiling the CNN
model.compile(optimizer= 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])

model.fit_generator(training_set, steps_per_epoch = 350, epochs = 15, validation_data = test_set, validation_steps = 100)

test_set 和 training_set 的批大小為 64

Model 總結:

Model: "sequential_24"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_142 (Conv2D)          (None, 61, 61, 64)        3136      
_________________________________________________________________
conv2d_143 (Conv2D)          (None, 29, 29, 64)        65600     
_________________________________________________________________
dropout_90 (Dropout)         (None, 29, 29, 64)        0         
_________________________________________________________________
batch_normalization_69 (Batc (None, 29, 29, 64)        256       
_________________________________________________________________
conv2d_144 (Conv2D)          (None, 26, 26, 128)       131200    
_________________________________________________________________
conv2d_145 (Conv2D)          (None, 12, 12, 128)       262272    
_________________________________________________________________
dropout_91 (Dropout)         (None, 12, 12, 128)       0         
_________________________________________________________________
batch_normalization_70 (Batc (None, 12, 12, 128)       512       
_________________________________________________________________
conv2d_146 (Conv2D)          (None, 9, 9, 256)         524544    
_________________________________________________________________
conv2d_147 (Conv2D)          (None, 3, 3, 256)         1048832   
_________________________________________________________________
dropout_92 (Dropout)         (None, 3, 3, 256)         0         
_________________________________________________________________
batch_normalization_71 (Batc (None, 3, 3, 256)         1024      
_________________________________________________________________
flatten_21 (Flatten)         (None, 2304)              0         
_________________________________________________________________
dropout_93 (Dropout)         (None, 2304)              0         
_________________________________________________________________
dense_42 (Dense)             (None, 512)               1180160   
_________________________________________________________________
dense_43 (Dense)             (None, 29)                14877     
=================================================================
Total params: 3,232,413
Trainable params: 3,231,517
Non-trainable params: 896
_________________________________________________________________

predict_classes 的文檔非常有限。 如果您使用基於下面顯示的文檔的生成器,它可能無法工作

x: input data, as a Numpy array or list of Numpy arrays (if the model has multiple inputs).

所以我認為它不適用於發電機。 所以你必須使用 model.predict。

predictions=model.predict(predict_set)
for p in predictions:
    class_index=np.argmax(p) # this is the integer value assogned to a class.

如果您使用生成器來訓練您的 model(我將其稱為 train_gen),那么您可以獲得 for (class name, class_index) 的 class_indices 字典,如下所示

class_dict=train_gen.class_indices
# reverse the dictionary
for key,value in class_dict.items():
        new_dict[value]=key  
````
now you can use the new_dict to get the class name with the code below
```` 
for p in predictions:
    class_index=np.argmax(p)
    class_name=new_dict[class_index]
````

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM