
How to predict input image using trained model in Keras, .h5 saved file?

I'm just getting started with Keras and machine learning.

I trained a model to classify images into 9 classes and saved it with model.save(). Here is the code I used:

from keras.layers import Input, Lambda, Dense, Flatten
from keras.models import Model
from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
import numpy as np
from glob import glob
import matplotlib.pyplot as plt

# re-size all the images to this
IMAGE_SIZE = [224, 224]

train_path = 'Datasets/Train'
valid_path = 'Datasets/Test'

# add preprocessing layer to the front of resnet
resnet = ResNet50(input_shape=IMAGE_SIZE + [3], weights='imagenet', include_top=False)

# don't train existing weights
for layer in resnet.layers:
    layer.trainable = False

# useful for getting the number of classes
folders = glob('Datasets/Train/*')

# our layers - you can add more if you want
x = Flatten()(resnet.output)
# x = Dense(1000, activation='relu')(x)
prediction = Dense(len(folders), activation='softmax')(x)

# create a model object
model = Model(inputs=resnet.input, outputs=prediction)

# view the structure of the model
model.summary()

# tell the model what cost and optimization method to use
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)


train_datagen = ImageDataGenerator(rescale=1. / 255,
                                   shear_range=0.1,
                                   zoom_range=0.1,
                                   horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1. / 255)

training_set = train_datagen.flow_from_directory('Datasets/Train',
                                                 target_size=(224, 224),
                                                 batch_size=32,
                                                 class_mode='categorical')

test_set = test_datagen.flow_from_directory('Datasets/Test',
                                            target_size=(224, 224),
                                            batch_size=32,
                                            class_mode='categorical')


# fit the model
r = model.fit_generator(
    training_set,
    validation_data=test_set,
    epochs=3,
    steps_per_epoch=len(training_set),
    validation_steps=len(test_set)
)

def plot_loss_accuracy(r):
    fig = plt.figure(figsize=(12, 6))
    ax = fig.add_subplot(1, 2, 1)
    ax.plot(r.history["loss"], 'r-x', label="Train Loss")
    ax.plot(r.history["val_loss"], 'b-x', label="Validation Loss")
    ax.legend()
    ax.set_title('cross_entropy loss')
    ax.grid(True)

    ax = fig.add_subplot(1, 2, 2)
    ax.plot(r.history["accuracy"], 'r-x', label="Train Accuracy")
    ax.plot(r.history["val_accuracy"], 'b-x', label="Validation Accuracy")
    ax.legend()
    ax.set_title('accuracy')
    ax.grid(True)
    plt.show()


plot_loss_accuracy(r)

# save the trained model so it can be reloaded later with load_model()
model.save('model.h5')
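The 9 outputs of the final Dense layer are ordered by training_set.class_indices, the {folder name: index} dict that flow_from_directory builds, and that is the order any list of class names must follow at prediction time. A minimal sketch, assuming the Keras DirectoryIterator behaviour of numbering class subfolders in sorted name order; the folder names and the helpers class_indices_like_keras and index_to_name are hypothetical illustrations, not Keras APIs:

```python
import os
import tempfile


def class_indices_like_keras(directory):
    """Mimic how flow_from_directory assigns label indices (assumption):
    class subfolders are sorted by name and numbered from 0.
    In a real script, prefer training_set.class_indices directly."""
    classes = sorted(
        d for d in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, d))
    )
    return {name: i for i, name in enumerate(classes)}


def index_to_name(class_indices):
    """Invert {name: index} into a list where position i holds class i's name."""
    names = [None] * len(class_indices)
    for name, i in class_indices.items():
        names[i] = name
    return names


# Hypothetical Datasets/Train layout with three class folders,
# created in a temporary directory so the sketch runs anywhere.
with tempfile.TemporaryDirectory() as root:
    for folder in ["zebra", "cat", "dog"]:
        os.mkdir(os.path.join(root, folder))
    indices = class_indices_like_keras(root)
    class_names = index_to_name(indices)

print(indices)      # {'cat': 0, 'dog': 1, 'zebra': 2}
print(class_names)  # ['cat', 'dog', 'zebra']
```

In the actual training script, index_to_name(training_set.class_indices) gives the list to hard-code (or save alongside model.h5) for the prediction script.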

It trained successfully. To load this model and test it on new images, I used the following code:

from keras.models import load_model
import cv2
import numpy as np

model = load_model('model.h5')

model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

img = cv2.imread('test.jpg')
img = cv2.resize(img,(320,240))
img = np.reshape(img,[1,320,240,3])

classes = model.predict_classes(img)

print(classes)

It outputs:

AttributeError: 'Model' object has no attribute 'predict_classes'

Why doesn't prediction work at all?

Thanks,

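predict_classes is only available in the Sequential API: http://faroit.com/keras-docs/1.0.0/models/sequential/ (later TensorFlow 2.x releases removed it from Sequential as well). A model built with the functional API, like this one, only has predict().

So you first need to get the class probabilities from predict() and take the index of the highest probability as the class. The input must also be preprocessed the same way as during training: RGB channel order, 224×224 pixels, values scaled by 1/255, and a leading batch axis, i.e. shape (1, 224, 224, 3).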
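For reference, the old Sequential.predict_classes did little more than this: argmax over predict() when the last layer has several outputs, and a 0.5 threshold when it has a single sigmoid output. A minimal stand-in sketch; predict_classes here is a hypothetical helper, not a Keras API, and FakeModel is a stub that returns fixed probabilities so the sketch runs without Keras or a saved .h5 file:

```python
import numpy as np


def predict_classes(model, x, threshold=0.5):
    """Hypothetical stand-in for the removed Sequential.predict_classes.

    Works with any object whose predict(x) returns an array of shape
    (batch, n_outputs), such as a model loaded with load_model().
    """
    probs = np.asarray(model.predict(x))
    if probs.shape[-1] > 1:
        # Softmax output (e.g. the 9-unit Dense layer): pick the top class.
        return np.argmax(probs, axis=-1)
    # Single sigmoid output: threshold the probability, shape (batch, 1).
    return (probs > threshold).astype("int32")


class FakeModel:
    """Stub with a Keras-style predict(); returns fixed probabilities."""

    def __init__(self, probs):
        self.probs = np.asarray(probs)

    def predict(self, x):
        return self.probs


softmax_model = FakeModel([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]])
print(predict_classes(softmax_model, None).tolist())  # [1, 0]

sigmoid_model = FakeModel([[0.8], [0.3]])
print(predict_classes(sigmoid_model, None).tolist())  # [[1], [0]]
```

With the real model, predict_classes(model, img) on a correctly preprocessed batch would return the same indices as np.argmax(model.predict(img), axis=-1).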

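from keras.models import load_model
import cv2
import numpy as np


# Names in the same order as training_set.class_indices
# (flow_from_directory numbers the class subfolders in sorted order)
class_names = ['a', 'b', 'c', ...] # fill the rest

# compile() is not required just to call predict()
model = load_model('model.h5')

img = cv2.imread('test.jpg')                # OpenCV loads images as BGR
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # the training generators produced RGB
img = cv2.resize(img, (224, 224))           # same size as IMAGE_SIZE in training
img = img.astype('float32') / 255.0         # same rescale=1./255 as in training
img = np.expand_dims(img, axis=0)           # add batch axis: (1, 224, 224, 3)

# predict() returns one probability per class; argmax picks the most likely one
classes = np.argmax(model.predict(img), axis=-1)

print(classes)

names = [class_names[i] for i in classes]

print(names)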
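The np.reshape(img, [1, 320, 240, 3]) pattern in the question is easy to miss because NumPy never complains. cv2.resize takes its size as (width, height), so cv2.resize(img, (320, 240)) returns an array of shape (240, 320, 3), and reshaping that to 320×240 only re-slices the same flat buffer: rows get scrambled rather than rotated or resized. A NumPy-only sketch on a toy 2×3 "image" (no OpenCV needed) compares it with np.expand_dims, which only adds the batch axis:

```python
import numpy as np

# Toy "image" laid out like cv2.imread output: (rows, cols, channels).
h, w = 2, 3
img = np.arange(h * w * 3).reshape(h, w, 3)

# Pattern from the question: reshape with height and width swapped.
wrong = np.reshape(img, [1, w, h, 3])
# Safe way to build a batch of one: add a leading axis, keep pixel layout.
right = np.expand_dims(img, axis=0)

print(right.shape)                                       # (1, 2, 3, 3)
print(np.array_equal(right[0], img))                     # True
# reshape is not a transpose either: it just re-slices the flat buffer.
print(np.array_equal(wrong[0], img.transpose(1, 0, 2)))  # False
# Pixel at row 1, col 0 differs between the original and the reshaped view.
print(img[1, 0].tolist(), wrong[0, 1, 0].tolist())       # [9, 10, 11] [6, 7, 8]
```

The same applies to any mismatch between the resized image and the model's input: resize to the training size first, then add the batch axis.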


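Notice: technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.

Guangdong ICP registration No. 18138465  © 2020-2024 STACKOOM.COM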