
How to predict a single sample with Keras

I'm trying to implement a Fully Convolutional Neural Network and can successfully test the accuracy of the model on the test set after training. However, I'd like to use the model to make a prediction on a single sample only. Training was in batches. I believe what I'm missing is related to batch size and input shape. Here is the configuration for the network:

import numpy as np
from tensorflow import keras

def read(file_name):
    data = np.loadtxt(file_name, delimiter="\t")
    y = data[:, 0]
    x = data[:, 1:]
    return x, y.astype(int)

train_data, train_labels = read("FordA_TRAIN.tsv")
test_data, test_labels = read("FordA_TEST.tsv")

train_data = train_data.reshape((train_data.shape[0], train_data.shape[1], 1))
test_data = test_data.reshape((test_data.shape[0], test_data.shape[1], 1))

num_classes = len(np.unique(train_labels))

#print(train_data[0])

# Shuffle the data to prepare for validation_split (and prevent overfitting for class order)

idx = np.random.permutation(len(train_data))
train_data = train_data[idx]
train_labels = train_labels[idx]

# Remap labels from {-1, 1} to {0, 1}; sparse_categorical_crossentropy expects class indices starting at 0.

train_labels[train_labels == -1] = 0
test_labels[test_labels == -1] = 0

def make_model(input_shape):
    input_layer = keras.layers.Input(input_shape)

    conv1 = keras.layers.Conv1D(filters=64, kernel_size=3, padding="same")(input_layer)
    conv1 = keras.layers.BatchNormalization()(conv1)
    conv1 = keras.layers.ReLU()(conv1)

    conv2 = keras.layers.Conv1D(filters=64, kernel_size=3, padding="same")(conv1)
    conv2 = keras.layers.BatchNormalization()(conv2)
    conv2 = keras.layers.ReLU()(conv2)

    conv3 = keras.layers.Conv1D(filters=64, kernel_size=3, padding="same")(conv2)
    conv3 = keras.layers.BatchNormalization()(conv3)
    conv3 = keras.layers.ReLU()(conv3)

    gap = keras.layers.GlobalAveragePooling1D()(conv3)

    output_layer = keras.layers.Dense(num_classes, activation="softmax")(gap)

    return keras.models.Model(inputs=input_layer, outputs=output_layer)


model = make_model(input_shape=train_data.shape[1:])
keras.utils.plot_model(model, show_shapes=True)

epochs = 500
batch_size = 32

callbacks = [
    keras.callbacks.ModelCheckpoint(
        "best_model.h5", save_best_only=True, monitor="val_loss"
    ),
    keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=20, min_lr=0.0001
    ),
    keras.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=50, verbose=1),
]
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
)
history = model.fit(
    train_data,
    train_labels,
    batch_size=batch_size,
    epochs=epochs,
    callbacks=callbacks,
    validation_split=0.2,
    verbose=1,
)

model = keras.models.load_model("best_model.h5")

test_loss, test_acc = model.evaluate(test_data, test_labels)

print("Test accuracy", test_acc)
print("Test loss", test_loss)
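
To show the array shapes this training script produces, here is a minimal sketch that applies the same read-and-reshape logic to a tiny synthetic TSV. The three rows and the length of 4 timesteps are made up for illustration (FordA rows have 500 values, matching the (500, 1) shape in the error reported further on). The read function here takes a file-like object instead of a file name so the example runs without the dataset.

```python
import io

import numpy as np

# Synthetic stand-in for a FordA .tsv file (hypothetical data):
# first column is the label (-1 or 1), the rest are the time-series values.
tsv_text = "-1\t0.1\t0.2\t0.3\t0.4\n1\t0.5\t0.6\t0.7\t0.8\n-1\t0.9\t1.0\t1.1\t1.2\n"


def read(file_obj):
    """Same logic as the question's read(), but accepts any file-like object."""
    data = np.loadtxt(file_obj, delimiter="\t")
    y = data[:, 0]
    x = data[:, 1:]
    return x, y.astype(int)


x, y = read(io.StringIO(tsv_text))
print(x.shape)  # (3, 4): (samples, timesteps)

# Conv1D expects (samples, timesteps, channels); a univariate series has 1 channel.
x = x.reshape((x.shape[0], x.shape[1], 1))
print(x.shape)  # (3, 4, 1)

# Map labels from {-1, 1} to {0, 1} for sparse_categorical_crossentropy.
y[y == -1] = 0
print(y)  # [0 1 0]

# make_model() receives x.shape[1:], the per-sample shape WITHOUT the batch axis.
print(x.shape[1:])  # (4, 1)
```

Because the model is built from the per-sample shape, Keras adds the batch axis itself (reported as None), and every array passed to fit, evaluate or predict must include it.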

The code above trains the model and reports the test accuracy it converged to. Now, I'd like to make predictions on single samples. So far I have:

import numpy as np
from tensorflow import keras

def read(file_name):
    data = np.loadtxt(file_name, delimiter="\t")
    y = data[:, 0]
    x = data[:, 1:]
    return x, y.astype(int)

test_data, test_labels = read("FordA_TEST_B.tsv")
test_data = test_data.reshape((test_data.shape[0], test_data.shape[1], 1))

test_labels[test_labels == -1] = 0

print(test_data)

model = keras.models.load_model("forda_original_model.h5")

q = model.predict(test_data[0])

This raises the error: ValueError: Error when checking input: expected input_1 to have 3 dimensions, but got array with shape (500, 1)
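
The error comes from how NumPy indexing works: an integer index removes the axis it selects, while a slice keeps it. A minimal sketch, using a zero-filled array as a hypothetical stand-in for test_data (the 10 samples are an arbitrary assumption; the per-sample shape (500, 1) matches the error message):

```python
import numpy as np

# Hypothetical stand-in for test_data: 10 samples, 500 timesteps, 1 channel.
test_data = np.zeros((10, 500, 1))

sample = test_data[0]  # integer index drops the batch axis
print(sample.shape)    # (500, 1) -> 2 dimensions, but the model wants 3

batch_of_one = test_data[0:1]  # a slice keeps the batch axis
print(batch_of_one.shape)      # (1, 500, 1)

# For this model, model.input_shape is (None, 500, 1); None is the batch axis,
# so every array passed to predict() must have ndim == 3.
expected_ndim = 3
print(sample.ndim == expected_ndim, batch_of_one.ndim == expected_ndim)  # False True
```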

How does the input have to be reshaped and what is the rule to go by? Any help is much appreciated!

Copied from a comment:

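The model expects a batch dimension. Indexing with test_data[0] drops that axis, leaving an array of shape (500, 1), while the model's input has shape (batch_size, 500, 1). To predict a single sample, expand the dimensions to create a batch of size one by running: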

q = model.predict(test_data[0][None,...])

or

q = model.predict(test_data[0][np.newaxis,...])
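
Both spellings produce the same (1, 500, 1) array, as do np.expand_dims(test_data[0], axis=0) and the slice test_data[0:1]. The output of predict keeps the batch axis too, so q has shape (1, num_classes) and the predicted class is q[0].argmax(). Below is a sketch of a small helper that wraps both steps; predict_single and StubModel are hypothetical names, and the stub only imitates a Keras-style predict method so the example runs without TensorFlow.

```python
import numpy as np


def predict_single(model, sample):
    """Hypothetical helper: predict one sample with a model trained on batches.

    `sample` has the per-sample shape (timesteps, channels), e.g. (500, 1).
    Returns (predicted class index, probability vector).
    """
    sample = np.asarray(sample)
    batch = np.expand_dims(sample, axis=0)  # (500, 1) -> (1, 500, 1)
    probs = model.predict(batch)            # shape (1, num_classes)
    return int(np.argmax(probs[0])), probs[0]


class StubModel:
    """Stand-in for a loaded Keras model; only checks the input rank."""

    def predict(self, x):
        assert x.ndim == 3, f"expected 3 dimensions, got shape {x.shape}"
        # Fake softmax output for 2 classes, one row per sample in the batch.
        return np.tile([0.2, 0.8], (x.shape[0], 1))


test_data = np.zeros((10, 500, 1))  # hypothetical stand-in for the real test set
sample = test_data[0]

# Three equivalent ways to build a batch of size one.
a = sample[None, ...]
b = sample[np.newaxis, ...]
c = np.expand_dims(sample, axis=0)
print(a.shape, b.shape, c.shape)  # (1, 500, 1) (1, 500, 1) (1, 500, 1)

label, probs = predict_single(StubModel(), sample)
print(label, probs)  # 1 [0.2 0.8]
```

With the real model the call is the same, e.g. predict_single(model, test_data[0]). If you need predictions for many samples, passing the whole array to model.predict(test_data) in one call is usually much faster than looping over single samples.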

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint this post, please indicate the site URL or the original address. For questions, contact yoyou2525@163.com.

 
© 2020-2024 STACKOOM.COM