简体   繁体   English

Keras对Python的预测

[英]Keras Predictions on python

I have this code for my CNN: 我的CNN有以下代码:

from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
from keras import backend as K

# dimensions of our images.
img_width, img_height = 64, 64

train_data_dir = "path_trainning"
validation_data_dir = "path_validation"
nb_train_samples = 2000
nb_validation_samples = 800
epochs = 10
batch_size = 16

if K.image_data_format() == 'channels_first':
    input_shape = (3, img_width, img_height)
else:
    input_shape = (img_width, img_height, 3)

model = Sequential()
model.add(Conv2D(32, (3, 3), input_shape=input_shape))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))

model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

# this is the augmentation configuration we will use for training
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

# this is the augmentation configuration we will use for testing:
# only rescaling
test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='binary')

model.fit_generator(
    train_generator,
    steps_per_epoch=nb_train_samples // batch_size,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=nb_validation_samples // batch_size)

model.save('my_cnn.h5')

And this is the code for my predictions: 这是我的预测代码:

for file in os.listdir(targets_path):
    filef = '\\' + file
    test_image = image.load_img(targets_path + filef, target_size=(64, 64))
    test_image = image.img_to_array(test_image)
    test_image = np.expand_dims(test_image, axis=0)
    result = model.predict(test_image)
    print("\nOriginal: " + file)
    print("Prediction: " + str(result[0][0]))
    if result[0][0] == 1:
        prediction = 'dog'
    else:
        prediction = 'cat'
    print(prediction)

My question is: 我的问题是:

With this code as the "Prediction" part, I am realising that unless the CNN has a 1, it won't be a dog. 通过将此代码作为“预测”部分,我意识到,除非CNN的值为1,否则它不会是狗。 And I am getting results like 0.99999 is a cat, but with that value it is closer to be a dog. 而且我得到的结果像0.99999就是一只猫,但是有了这个值,它更像一只狗。

I think I am not understanding it properly. 我想我不太了解。

Could someone explain me please? 有人可以解释一下吗?

这是由于输出层是具有S型激活的节点,该节点返回0到1之间的值。因此,结果永远不会为1(或0),因此代码将始终返回“ cat”。

This might be the issue with your CNN. 这可能是您的CNN的问题。

You are using an ReLU activation in the hidden layers. 您正在隐藏层中使用ReLU激活。 They have an output range from 0 to Infinity. 它们的输出范围是0到无穷大。 When these values flow through the final output activation which is sigmoid in your case. 当这些值流过最终输出激活时(在您的情况下为S型)。 If I pass a greater value like 25 to sigmoid the output will be close to 1. The same would happen with a very small value which would result in threshold closer to 0. 如果我将一个较大的值(如25)传递给S形,则输出将接近1。如果值很小,也会发生相同的情况,这将导致阈值更接近0。

You should use a softmax function at the output layer if you are using ReLU in the hidden layers. 如果在隐藏层中使用ReLU,则应在输出层使用softmax函数。 Softmax converts the logits to class probabilities. Softmax将logit转换为类概率。

And also, with softmax, you would use categorical classes and not binary. 而且,对于softmax,您将使用分类类,而不是二进制。 You will have 2 classes and hence 2 output nodes. 您将有2个类,因此有2个输出节点。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM