
TensorFlow dense layer input data shape for MNIST

I trained a neural network on the MNIST dataset, following the exact code from TensorFlow. Now I want to use the model to predict on an image I created, but there seems to be an issue with the input size. I used OpenCV to read the image and convert it to [28,28,1] (grayscale), which I think is the same input the model was trained on from tensorflow_datasets.

import cv2
import tensorflow as tf

img = cv2.imread("168.png", 0)          # read as grayscale, uint8
img = cv2.resize(img, (28, 28))         # scale down to the input dimensions
img = tf.cast(img, tf.float32) / 255.   # uint8 -> float32 in [0, 1], like the training data
img.shape
>>> TensorShape([28, 28])

Inspecting ds_train with ds_train.element_spec gives:

>>> (TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name=None),
 TensorSpec(shape=(None,), dtype=tf.int64, name=None))

But model.predict(img) gives this error:

WARNING:tensorflow:Model was constructed with shape (None, 28, 28) for input Tensor("flatten_input:0", shape=(None, 28, 28), dtype=float32), but it was called on an input with incompatible shape (None, 28).

ValueError: Input 0 of layer dense is incompatible with the layer: expected axis -1 of input shape to have value 784 but received input with shape [None, 28]
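The (None, 28) in the warning comes from Keras treating the first axis of whatever is passed to predict() as the batch axis. A minimal NumPy sketch of the shape effect of a Flatten layer (an illustration of the shape logic only, not Keras's internal code; flatten_per_sample is a hypothetical helper):

```python
import numpy as np

def flatten_per_sample(batch):
    # Keep axis 0 (the batch axis) and collapse every other axis into one,
    # which is what a Flatten layer does to the shape.
    return batch.reshape(batch.shape[0], -1)

img = np.zeros((28, 28), dtype=np.float32)

# Passed as-is, axis 0 is read as the batch: 28 "samples" of 28 values each,
# so the Dense layer sees 28 features per sample instead of 784.
wrong = flatten_per_sample(img)
print(wrong.shape)  # (28, 28)

# With an explicit batch axis there is 1 sample of 28 * 28 = 784 values.
right = flatten_per_sample(img[np.newaxis, ...])
print(right.shape)  # (1, 784)
```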

Any idea how to fix the shape? I'm using TensorFlow 2.3.1 (tf.__version__).

You need to reshape it to [1, 28, 28, 1], not [28, 28]. You also don't need OpenCV for this step; tf.reshape() is enough:

img = tf.reshape(img, [1, 28, 28, 1])
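tf.reshape works here because adding size-1 axes does not change the number of elements. For illustration, the same result with NumPy (tf.expand_dims behaves like np.expand_dims on tensors); the random array is just a stand-in for the normalized image:

```python
import numpy as np

img = np.random.rand(28, 28).astype(np.float32)   # stand-in for the normalized 28 x 28 image

a = img.reshape(1, 28, 28, 1)                     # what tf.reshape(img, [1, 28, 28, 1]) does
b = img[np.newaxis, :, :, np.newaxis]             # insert batch and channel axes by indexing
c = np.expand_dims(np.expand_dims(img, -1), 0)    # channel axis first, then batch axis
print(a.shape, b.shape, c.shape)

# reshape only rearranges existing values, so it needs exactly 28 * 28 = 784 of them
failed = False
try:
    np.zeros((30, 30)).reshape(1, 28, 28, 1)
except ValueError as err:
    failed = True
    print("reshape failed:", err)
```

The last part is why reshaping alone is not enough for an image of any other size: it has to be resized to 28 x 28 first.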

After training, a real image has to be preprocessed the same way as the training data before you run inference on it. In your case, ds_train has input shape (None, 28, 28, 1): the image is 28 pixels high and 28 pixels wide, the 1 is the single channel of a grayscale image, and None is the batch size, which is left unspecified. So when you load an image for prediction, first load it as grayscale. Next, resize it to the input size (28 in your case) and scale the pixel values to [0, 1] as in training. Then add a channel axis, which turns the 28 x 28 input into 28 x 28 x 1. Finally, add the batch dimension with the reshape that @yudhesh already suggested, giving 1 x 28 x 28 x 1.
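The steps above can be sketched in plain NumPy. This is an illustration under assumptions, not the answer's own code: preprocess_for_mnist and resize_nearest are hypothetical helpers, and nearest-neighbour resizing is swapped in only to keep the sketch dependency-free.

```python
import numpy as np

def resize_nearest(gray, size):
    # Nearest-neighbour resize: pick the source row/column for each target pixel.
    h, w = gray.shape
    rows = (np.arange(size) * h / size).astype(int)
    cols = (np.arange(size) * w / size).astype(int)
    return gray[rows[:, None], cols[None, :]]

def preprocess_for_mnist(gray, size=28):
    """Hypothetical helper: 2-D uint8 grayscale image -> (1, size, size, 1) float32."""
    x = resize_nearest(gray, size)        # 1. resize to the training resolution
    x = x.astype(np.float32) / 255.0      # 2. scale to [0, 1] like the training data
    x = x[..., np.newaxis]                # 3. add the channel axis -> (size, size, 1)
    return x[np.newaxis, ...]             # 4. add the batch axis -> (1, size, size, 1)

gray = np.random.randint(0, 256, size=(357, 349), dtype=np.uint8)  # arbitrary 357 x 349 image
batch = preprocess_for_mnist(gray)
print(batch.shape, batch.dtype)  # (1, 28, 28, 1) float32
```

cv2.resize (linear interpolation by default) or tf.image.resize (bilinear by default) give smoother results than nearest-neighbour. One more thing worth checking, as an assumption about your image rather than something the question states: MNIST digits are light strokes on a dark background, so a digit drawn dark-on-light may need inverting (255 - gray) before normalizing.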

Sample Code

Here is a complete example:

import tensorflow as tf

(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255   # (60000, 28, 28), scaled to [0, 1]
x_train = x_train[..., tf.newaxis]           # add the channel axis -> (60000, 28, 28, 1)
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)  # one-hot labels

inputs = tf.keras.Input(shape=(28, 28, 1))
x = tf.keras.layers.Conv2D(10, 3, activation="relu")(inputs)  # -> (None, 26, 26, 10)
outputs = tf.keras.layers.GlobalMaxPooling2D()(x)             # -> (None, 10); no softmax
model = tf.keras.Model(inputs, outputs)

# compile
model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
    optimizer=tf.keras.optimizers.Adam())

# fit() returns a History object with the per-epoch loss and metrics
history = model.fit(x_train, y_train, batch_size=512, epochs=10, verbose=2)
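To see why this small model ends with exactly 10 outputs (one per class), here is a NumPy sketch of the shapes. Conv2D with its default padding='valid' and strides=1 shrinks 28 x 28 to 26 x 26 with a 3 x 3 kernel, and GlobalMaxPooling2D takes the maximum over the two spatial axes. The random array only stands in for the Conv2D output; no convolution is computed, and valid_conv_size is a hypothetical helper.

```python
import numpy as np

def valid_conv_size(n, kernel, stride=1):
    # Output length of an unpadded ('valid') convolution along one axis.
    return (n - kernel) // stride + 1

batch, filters = 4, 10
side = valid_conv_size(28, 3)                           # 26
conv_out = np.random.rand(batch, side, side, filters)   # stand-in for Conv2D(10, 3) output

pooled = conv_out.max(axis=(1, 2))   # GlobalMaxPooling2D: max over height and width
print(side, conv_out.shape, pooled.shape)  # 26 (4, 26, 26, 10) (4, 10)
```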

Inference

Let's define a function that does the preprocessing:

# a preprocess function
def infer_prec(img, img_size):
    img = tf.expand_dims(img, -1)                      # (H, W) -> (H, W, 1); tf.image.resize needs a channel axis
    img = tf.cast(img, tf.float32) / 255.0             # normalize to [0, 1] as in training
    img = tf.image.resize(img, [img_size, img_size])   # resize to the model's input size
    img = tf.reshape(img,                              # reshape to add the batch dimension
                     [1, img_size, img_size, 1])
    return img

import cv2

img = cv2.imread('1.png', cv2.IMREAD_GRAYSCALE)   # read the image as grayscale (uint8)
print(img.shape)   # (357, 349)

img = infer_prec(img, 28)  # call the preprocess function
print(img.shape)   # (1, 28, 28, 1)

y_pred = model.predict(img)
y_pred  # raw scores: the model ends in GlobalMaxPooling2D with no softmax, so these are not probabilities
# array([[0.        , 0.        , 0.        , 0.        , 0.0304523 ,
#         0.        , 0.02225076, 0.        , 0.18298316, 0.        ]],
#       dtype=float32)

# get predicted label 
tf.argmax(y_pred, axis=-1).numpy() # array([8], dtype=int64)
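Since the model has no softmax layer, y_pred holds non-negative scores, and the predicted class is simply the highest one. A NumPy sketch reusing the ten scores printed above: it ranks the classes and rescales each row to sum to 1, which roughly mirrors how CategoricalCrossentropy(from_logits=False) treats outputs that are not already probabilities. Treat the rescaled values as a rough confidence indication, not calibrated probabilities.

```python
import numpy as np

# The ten scores printed above for the example image (ReLU outputs, no softmax).
y_pred = np.array([[0., 0., 0., 0., 0.0304523,
                    0., 0.02225076, 0., 0.18298316, 0.]], dtype=np.float32)

label = int(np.argmax(y_pred, axis=-1)[0])   # same result as tf.argmax(...).numpy()
top3 = np.argsort(y_pred[0])[::-1][:3]       # classes ranked by score, best first

# Rescale each row to sum to 1, leaving an all-zero row as zeros.
totals = y_pred.sum(axis=-1, keepdims=True)
probs = np.divide(y_pred, totals, out=np.zeros_like(y_pred), where=totals > 0)

print(label, top3.tolist(), round(float(probs[0, label]), 2))  # 8 [8, 4, 6] 0.78
```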
