繁体   English   中英

Keras ImageDataGenerator 预测超过预测集

[英]Keras ImageDataGenerator Predicts More Than The Prediction Set

我正在尝试制作一个分类美国手语的 CNN model。 我已经创建并训练了我的 model。 现在我正在尝试预测课程。 我的预测集有 7250 个未标记的图像,但是当我进行预测时,model 执行 587250 个预测,而我需要它来执行 7250 个预测。 我提供下面的代码。 这是什么原因? 难道我做错了什么?

代码块 1:

predict_set = data.flow_from_directory('/content/gdrive/MyDrive/test_data/')

Output 1:

Found 7250 images belonging to 1 classes.

代码块 2:

import numpy as np
predictions = model.predict_classes(predict_set)
print(len(predictions),"\n", predictions)

Output 2:

587250 
[27 27 27 ... 27  6 12]

编辑:

CNN Model:

model = Sequential()

# First Layer
model.add(Conv2D(filters = 64, kernel_size = (4, 4), input_shape = (64, 64, 3), activation = 'relu'))
model.add(Conv2D(filters = 64, kernel_size = (4, 4), strides = 2,  activation = 'relu'))
model.add(Dropout(0.5))
model.add(BatchNormalization(axis = 3, momentum = 0.8))

# Second Layer
model.add(Conv2D(filters = 128, kernel_size = (4, 4), activation = 'relu'))
model.add(Conv2D(filters = 128, kernel_size = (4, 4), strides = 2,  activation = 'relu'))
model.add(Dropout(0.5))
model.add(BatchNormalization(axis = 3, momentum = 0.8))

# Third Layer
model.add(Conv2D(filters = 256, kernel_size = (4, 4), activation = 'relu'))
model.add(Conv2D(filters = 256, kernel_size = (4, 4), strides = 2,  activation = 'relu'))
model.add(Dropout(0.5))
model.add(BatchNormalization(axis = 3, momentum = 0.8))

# Flattening
model.add(Flatten())
model.add(Dropout(0.5))

model.add(Dense(units = 512, activation = 'relu')) # Hidden Layer
model.add(Dense(units = 29, activation = 'softmax')) # Output Layer

#Compiling the CNN
model.compile(optimizer= 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])

model.fit_generator(training_set, steps_per_epoch = 350, epochs = 15, validation_data = test_set, validation_steps = 100)

test_set 和 training_set 的批大小为 64

Model 总结:

Model: "sequential_24"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_142 (Conv2D)          (None, 61, 61, 64)        3136      
_________________________________________________________________
conv2d_143 (Conv2D)          (None, 29, 29, 64)        65600     
_________________________________________________________________
dropout_90 (Dropout)         (None, 29, 29, 64)        0         
_________________________________________________________________
batch_normalization_69 (Batc (None, 29, 29, 64)        256       
_________________________________________________________________
conv2d_144 (Conv2D)          (None, 26, 26, 128)       131200    
_________________________________________________________________
conv2d_145 (Conv2D)          (None, 12, 12, 128)       262272    
_________________________________________________________________
dropout_91 (Dropout)         (None, 12, 12, 128)       0         
_________________________________________________________________
batch_normalization_70 (Batc (None, 12, 12, 128)       512       
_________________________________________________________________
conv2d_146 (Conv2D)          (None, 9, 9, 256)         524544    
_________________________________________________________________
conv2d_147 (Conv2D)          (None, 3, 3, 256)         1048832   
_________________________________________________________________
dropout_92 (Dropout)         (None, 3, 3, 256)         0         
_________________________________________________________________
batch_normalization_71 (Batc (None, 3, 3, 256)         1024      
_________________________________________________________________
flatten_21 (Flatten)         (None, 2304)              0         
_________________________________________________________________
dropout_93 (Dropout)         (None, 2304)              0         
_________________________________________________________________
dense_42 (Dense)             (None, 512)               1180160   
_________________________________________________________________
dense_43 (Dense)             (None, 29)                14877     
=================================================================
Total params: 3,232,413
Trainable params: 3,231,517
Non-trainable params: 896
_________________________________________________________________

predict_classes 的文档非常有限。 如果您使用基于下面显示的文档的生成器,它可能无法工作

x: input data, as a Numpy array or list of Numpy arrays (if the model has multiple inputs).

所以我认为它不适用于发电机。 所以你必须使用 model.predict。

predictions=model.predict(predict_set)
for p in predictions:
    class_index=np.argmax(p) # this is the integer value assogned to a class.

如果您使用生成器来训练您的 model(我将其称为 train_gen),那么您可以获得 for (class name, class_index) 的 class_indices 字典,如下所示

class_dict=train_gen.class_indices
# reverse the dictionary
for key,value in class_dict.items():
        new_dict[value]=key  
````
now you can use the new_dict to get the class name with the code below
```` 
for p in predictions:
    class_index=np.argmax(p)
    class_name=new_dict[class_index]
````

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM