简体   繁体   English

使用经过训练的 TensorFlow model 预测单个 PNG 图像

[英]Predicting a single PNG image using a trained TensorFlow model

import tensorflow as tf
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape = (28,28)),
    tf.keras.layers.Dense(128, activation = 'relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
]) 

This is the code for the model, which I have trained using the mnist dataset.这是我使用 mnist 数据集训练的 model 的代码。 What I want to do is to then pass a 28x28 png image to the predict() method, which is not working.我想要做的是然后将 28x28 png 图像传递给 predict() 方法,该方法不起作用。 The code for the prediction is:预测代码如下:

img = imageio.imread('image_0.png')
prediction = model.predict(img, batch_size = 1)

which produces the error产生错误

ValueError: Error when checking input: expected flatten_input to have shape (28, 28) but got array with shape (28, 3)

I have been stuck on this problem for a few days, but I can't find the correct way to pass an image into the predict method.我被这个问题困扰了几天,但我找不到将图像传递给预测方法的正确方法。 Any help?有什么帮助吗?

Predict function makes predictions over a batch of image. Predict function 对一批图像进行预测。 You should include batch dimension (first dimension) to your img, even to predict a single example.您应该在您的 img 中包含批量维度(第一维度),甚至可以预测单个示例。 You need something like this:你需要这样的东西:

img = imageio.imread('image_0.png')
img = np.expand_dims(img, axis=0)
prediction = model.predict(img)

As @desertnaut says, seems you are using a RGB image, so your first layer should use input_shape = (28,28,3) .正如@desertnaut 所说,您似乎正在使用 RGB 图像,因此您的第一层应该使用input_shape = (28,28,3) Therefore, img parameter of predict function should have (1,28,28,3) shape.因此,预测 function 的img参数应该具有 (1,28,28,3) 形状。

In your case, img parameter of predict function has (28,28,3) shape, thus predict function took the first dimension as number of images, and could not match the other two dimensions to the input_shape of the first layer.在您的情况下,预测 function 的img参数具有 (28,28,3) 形状,因此预测 function 将第一个维度作为图像数量,并且无法将其他两个维度与第一层的input_shape匹配。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM