
Why is my CNN performing much worse in Python than in MATLAB?

I trained a CNN in MATLAB R2019b that performs binary classification. On a test dataset it reached about 95% accuracy. I used the exportONNXNetwork function so that I could use my CNN in TensorFlow/Keras. This is the code I am using to run the ONNX file in Python:

import onnx
from onnx_tf.backend import prepare
import numpy as np
from numpy import array
from IPython.display import display
from PIL import Image

onnx_model = onnx.load("model.onnx")
tf_rep = prepare(onnx_model)
img = Image.open("image.jpg").resize((224,224))
img = array(img).reshape(1,3,224,224)
img = img.astype(np.uint8)

classification = tf_rep.run(img)
print(classification)

When I tested this Python code on the same test dataset, it classified almost everything as class 0, with only a few cases of class 1. I am not sure why this is happening.

At a glance, I think you need to permute the image axes rather than reshape them:

img = Image.open("image.jpg").resize((224,224))
img = array(img).transpose(2, 0, 1)  # (224, 224, 3) -> (3, 224, 224)
img = np.expand_dims(img, 0)         # add the batch axis -> (1, 3, 224, 224)

The image you get from PIL is in channels-last format, i.e. a tensor of shape (height, width, channels), in this case (224, 224, 3). Your model expects the input in channels-first format, i.e. a tensor of shape (channels, height, width), in this case (3, 224, 224).
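The conversion can be wrapped in a small helper so every image goes through the same steps. This is a minimal sketch, assuming the exported network takes one RGB image as a float32 tensor of shape (1, 3, H, W); the function name `to_nchw_batch` is hypothetical, and the float32 default is an assumption you should check against the input type declared in your ONNX model. Any normalization (for example MATLAB's zero-centering) is assumed to be either part of the exported graph or applied separately.

```python
import numpy as np


def to_nchw_batch(hwc: np.ndarray, dtype=np.float32) -> np.ndarray:
    """Turn one channels-last image (H, W, C) into a channels-first
    batch of size one, (1, C, H, W).

    Hypothetical helper: the float32 default is an assumption about the model.
    """
    if hwc.ndim != 3:
        raise ValueError(f"expected an (H, W, C) image, got shape {hwc.shape}")
    # Permute axes (this moves data); do NOT reshape: (H, W, C) -> (C, H, W).
    chw = np.transpose(hwc, (2, 0, 1))
    # Add the leading batch axis: (C, H, W) -> (1, C, H, W), then cast.
    return np.expand_dims(chw, 0).astype(dtype)


# Usage on a toy 4 x 4 RGB image whose pixel (0, 0) holds the values (0, 1, 2).
img = np.arange(48, dtype=np.uint8).reshape(4, 4, 3)
batch = to_nchw_batch(img)
print(batch.shape)        # (1, 3, 4, 4)
print(batch[0, :, 0, 0])  # [0. 1. 2.]
```

With PIL, the input would be something like `np.asarray(Image.open("image.jpg").convert("RGB").resize((224, 224)))`; the `convert("RGB")` call guards against grayscale or RGBA files, which would otherwise produce a different number of channels.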

You need to move the last axis to the front. reshape does not move any data: NumPy traverses the array in C order (last axis index changing fastest) and simply refills the new shape, so your image ends up scrambled. This is easier to understand with an example:

>>> img = np.arange(48).reshape(4, 4, 3)
>>> img[0, 0, :]
array([0, 1, 2])

The RGB values of the (0, 0) pixel are (0, 1, 2). If you use np.transpose(), this is preserved:

>>> img.transpose(2, 0, 1)[:, 0, 0]
array([0, 1, 2])

If you use reshape, your image will get scrambled:

>>> img.reshape(3, 4, 4)[:, 0, 0]
array([ 0, 16, 32])
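A few more checks on the same toy 4 × 4 RGB array make the difference concrete. The sketch below confirms that reshape only re-reads the flat, C-ordered buffer (in shape (3, 4, 4), element [c, 0, 0] sits at flat index c × 16, hence 0, 16, 32), that `np.moveaxis(img, -1, 0)` is an equivalent spelling of `transpose(2, 0, 1)`, and that the inverse permutation (1, 2, 0) restores the original layout. That inverse can be handy for visually inspecting whatever array you actually feed the network.

```python
import numpy as np

# Toy channels-last "image": 4 x 4 pixels, 3 channels, pixel (0, 0) = (0, 1, 2).
img = np.arange(48).reshape(4, 4, 3)

# reshape never moves data: it re-reads the flat C-order buffer with a new shape.
scrambled = img.reshape(3, 4, 4)
assert np.array_equal(scrambled, img.ravel(order="C").reshape(3, 4, 4))

# In shape (3, 4, 4), element [c, 0, 0] is at flat index c * 4 * 4 = c * 16.
print(scrambled[:, 0, 0])  # [ 0 16 32]

# transpose permutes the axes; moveaxis (last axis -> front) is equivalent.
chw = img.transpose(2, 0, 1)
assert np.array_equal(chw, np.moveaxis(img, -1, 0))

# The inverse permutation (1, 2, 0) recovers the original HWC image exactly,
# whereas applying it to the reshaped array does not.
assert np.array_equal(chw.transpose(1, 2, 0), img)
print(np.array_equal(scrambled.transpose(1, 2, 0), img))  # False
```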
