
Poor model performances when doing multi-class classification

Context

I have a dataset of medical X-rays (example). I want to train a model to recognize an overbite. The possible values are:

  • Normal
  • 1-2mm
  • 2-4mm
  • [...]
  • 8mm+

Test Results

I've built a CNN to process the images. My problem is that the validation accuracy is extremely low when I train on more than two classes. I tried training on different combinations of classes, and here are the results:

| Classes trained on | Val accuracy |
| ------------------ | ------------ |
| A -> B             | 56%          |
| B -> C             | 33%          |
| A -> C             | 75%          |
| A -> B -> C        | 17%          |

When I train on just two classes at a time (one against the other), the model seems to learn better than when I use all three. Why is that? In total, I have:

  • 1368 images of A
  • 1651 images of B
  • 449 images of C

(I realize 3.5K images is not a lot of data, but I'm trying to figure out the fundamentals of a good model before downloading and training on more data. My DB has only 17K images.)

Code

I have a custom input pipeline that generates a tf.data.Dataset:

```python
print(train_ds)
# ==> <ParallelMapDataset element_spec=(TensorSpec(shape=(512, 512, 1), dtype=tf.float32, name=None), TensorSpec(shape=(3,), dtype=tf.uint8, name=None))>
```

Here is the CNN architecture:

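```python
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
from tensorflow.keras.models import Sequential

# IMG_SIZE, color_channels, class_names, batch_train_ds, batch_val_ds and
# class_weights come from my input pipeline (512x512 grayscale images with
# one-hot labels for 3 classes, per the element_spec above).
input_shape = (None, IMG_SIZE, IMG_SIZE, color_channels)
num_classes = len(class_names)

# Pre-processing layers: downscale to 256x256 and scale pixels to [0, 1]
RESIZED_IMG = 256
resize_and_rescale = tf.keras.Sequential([
  layers.Resizing(RESIZED_IMG, RESIZED_IMG),
  layers.Rescaling(1./255)
])

# Augmentation layers (active only during training)
medium = 0.2
micro = 0.10
data_augmentation = tf.keras.Sequential([
  layers.RandomContrast(medium),
  layers.RandomBrightness(medium),
  layers.RandomRotation(micro, fill_mode="constant"),
  layers.RandomTranslation(micro, micro, fill_mode="constant"),
  layers.RandomZoom(micro, fill_mode="constant"),
])

# Hidden layers
model = Sequential([
  data_augmentation,
  resize_and_rescale,

  # Feature extraction: two convolutions, then one 2x2 pooling step
  Conv2D(16, 3, padding='same', activation='relu'),
  Conv2D(24, 5, padding='same', activation='relu'),
  MaxPooling2D(),

  # Classification head
  Flatten(),
  Dense(128, activation='relu'),
  Dense(num_classes, activation='softmax'),
])

# Build; labels are one-hot vectors, hence categorical_crossentropy
model.compile(
  optimizer='adam',
  loss='categorical_crossentropy',
  metrics=['accuracy'])

model.build(input_shape)
model.summary()

# Start training
epochs = 15
# Stop after 7 epochs without val_accuracy improvement; keep the best weights
early_stopping_monitor = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    restore_best_weights=True,
    patience=7
)
# Save the best model (by val_accuracy) seen so far
mcp_save = tf.keras.callbacks.ModelCheckpoint(
    '.mdl_wts.hdf5',
    save_best_only=True,
    monitor='val_accuracy'
)

history = model.fit(
  batch_train_ds,
  validation_data=batch_val_ds,
  epochs=epochs,
  class_weight=class_weights,
  callbacks=[early_stopping_monitor, mcp_save]
)
```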

The only thing I changed between runs is which images are loaded in my input pipeline; I then recorded the resulting accuracy. I have intentionally kept the CNN small because I don't have a lot of data.

Questions

  • Why does my model perform worse when training on more classes?
  • Do I have the wrong data and the images do not have enough conclusive information?
  • Is my image count too low in order to train a decent ML model?
  • Is my CNN not deep enough for multi-class classification?

Model capacity

The model is too simple (just two convolutional layers) to process images with such complex structures. Two convolutional layers may be able to recognize simple patterns such as edges and lines, but nothing more complicated.

The model lacks the capacity to represent more complicated structures, such as teeth and the different parts of the skull, which it would need in order to judge the overbite. It simply cannot make sense of the image. Compare the depth of your model to architectures such as ResNet, which stack dozens of convolutional layers where yours has two.
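A hand count of the parameters in the question's model shows that "too simple" is about depth rather than raw size. This is a sketch that assumes 3 classes and the 256x256x1 tensor produced by the Resizing layer: the two convolutional layers hold fewer than 10,000 parameters, while the single Dense layer after Flatten holds about 50 million. For comparison, ResNet-50 has roughly 25 million parameters spread over about 50 layers. `model.summary()` in the question should report the same per-layer counts.

```python
def conv2d_params(kernel, in_ch, out_ch):
    # kernel*kernel*in_ch weights per filter, plus one bias per filter
    return kernel * kernel * in_ch * out_ch + out_ch

def dense_params(in_features, units):
    # one weight per input per unit, plus one bias per unit
    return in_features * units + units

# Shapes follow the question's model (assumption: 3 classes):
# Resizing gives 256x256x1, 'same' padding keeps 256x256,
# and MaxPooling2D() halves the spatial size to 128x128.
conv1 = conv2d_params(3, 1, 16)    # Conv2D(16, 3)
conv2 = conv2d_params(5, 16, 24)   # Conv2D(24, 5)
flat = 128 * 128 * 24              # size of the Flatten() output
dense1 = dense_params(flat, 128)   # Dense(128)
dense2 = dense_params(128, 3)      # Dense(num_classes)
total = conv1 + conv2 + dense1 + dense2

print(f"conv layers: {conv1 + conv2:,} params")
print(f"Dense(128):  {dense1:,} params ({dense1 / total:.2%} of {total:,})")
```

Nearly all of the capacity sits in a fully connected layer that sees raw pixel-level feature maps, which makes it easy to memorize the training images without learning anatomical structure.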

Data

As you said, your dataset is relatively small. Even with a model that has enough capacity to understand the images, the dataset might be too small for it to generalize, and the model will overfit your data instead. You can recognize this by watching the metrics for your training and validation data: the training score keeps increasing while the validation score does not. That is called overfitting.
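The check described above can be automated with a small helper that reads the `history` object returned by `model.fit`. This is a crude heuristic, not a standard Keras utility: `overfitting_report` and its `window` parameter are hypothetical names, and it assumes the `accuracy` and `val_accuracy` metrics compiled in the question.

```python
def overfitting_report(history, window=3):
    """Compare the last `window` epochs of training vs. validation accuracy.

    `history` is the dict Keras stores in `model.fit(...).history`
    (assumed to contain "accuracy" and "val_accuracy").
    """
    train = history["accuracy"]
    val = history["val_accuracy"]
    if len(train) <= window:
        raise ValueError("need more epochs than `window`")
    # Change over the last `window` epochs
    train_gain = train[-1] - train[-1 - window]
    val_gain = val[-1] - val[-1 - window]
    # Generalization gap at the final epoch
    gap = train[-1] - val[-1]
    # Overfitting signature: training still improves while validation stalls or drops
    overfitting = train_gain > 0 and val_gain <= 0
    return {"train_gain": train_gain, "val_gain": val_gain,
            "gap": gap, "overfitting": overfitting}
```

Usage: `overfitting_report(history.history)`. Plotting both curves per epoch gives the same information visually; a large and growing `gap` is the thing to look for.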

Practical standpoint

I would suggest that you use a pre-trained model. These models were trained on massive datasets with millions of images, so their ability to generalize is quite high. You download the model and fine-tune it on your dataset. See the TensorFlow tutorial on how to use a pre-trained model and fine-tune it.

Starting with a pre-trained model is usually better than starting from scratch. Only if the results were poor would I move on to training a custom model.
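A minimal sketch of the two-phase approach (train a new head on a frozen backbone, then unfreeze the top of the backbone with a low learning rate). It uses MobileNetV2 because that is the backbone the TensorFlow transfer-learning tutorial uses; ResNet50 from `tf.keras.applications` could be swapped in with its own preprocessing. Assumptions, not taken from the question: the 160x160 backbone input size, inputs in the [0, 255] range, the layer count `n_trainable=30`, and the learning rates. The function names are hypothetical.

```python
import tensorflow as tf

NUM_CLASSES = 3   # assumption: classes A, B and C from the question
IMG_SIZE = 160    # assumption: backbone input size, as in the TF tutorial

def build_transfer_model(num_classes=NUM_CLASSES, input_size=512,
                         img_size=IMG_SIZE, weights="imagenet"):
    """Phase 1: frozen MobileNetV2 backbone plus a new classification head."""
    base = tf.keras.applications.MobileNetV2(
        input_shape=(img_size, img_size, 3), include_top=False, weights=weights)
    base.trainable = False  # keep the pre-trained features fixed at first

    inputs = tf.keras.Input(shape=(input_size, input_size, 1))  # grayscale X-rays
    x = tf.keras.layers.Resizing(img_size, img_size)(inputs)
    # ImageNet models expect 3 channels, so repeat the single gray channel
    x = tf.keras.layers.Concatenate()([x, x, x])
    # MobileNetV2 expects pixels in [-1, 1]; assumes inputs are in [0, 255]
    x = tf.keras.layers.Rescaling(1.0 / 127.5, offset=-1.0)(x)
    # training=False keeps BatchNorm in inference mode, even after unfreezing
    x = base(x, training=False)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)

    model = tf.keras.Model(inputs, outputs)
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
                  loss="categorical_crossentropy", metrics=["accuracy"])
    return model, base


def unfreeze_top_layers(model, base, n_trainable=30, learning_rate=1e-5):
    """Phase 2: unfreeze only the last `n_trainable` backbone layers."""
    base.trainable = True
    for layer in base.layers[:-n_trainable]:
        layer.trainable = False
    # Recompile so the change takes effect; the low learning rate avoids
    # wiping out the pre-trained weights
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate),
                  loss="categorical_crossentropy", metrics=["accuracy"])
```

Typical use: `model, base = build_transfer_model()`, fit for a few epochs, then `unfreeze_top_layers(model, base)` and fit again for a few more epochs on the same datasets.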

Conclusion

I believe the dataset size is sufficient if you use a pre-trained model and fine-tune it. But if you have more data available, I would consider using it; the model will likely give you better results.

There might be many other reasons why the model performs the way it does, but as for your first question, we can probably say that your model is performing poorly because the amount of data provided to it is not enough to learn all the complexities of multiple classes. For a multi-class classifier, it would be better to add more data so that it can learn the added complexity.

When you build a binary classifier, it's easier for the model to learn features that distinguish the two classes from each other; the same cannot be said for a multi-class classifier.
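It also helps to read each accuracy in the question's table against what a trivial classifier would score for that class combination, since the baseline changes with the number of classes and with the imbalance. This sketch computes the uniform-guess and majority-class baselines from the image counts in the question, assuming the validation split has the same class proportions as the full dataset.

```python
# Image counts per class, taken from the question
COUNTS = {"A": 1368, "B": 1651, "C": 449}

def baselines(classes):
    """Return (uniform-guess accuracy, majority-class accuracy) for a class subset."""
    n = [COUNTS[c] for c in classes]
    return 1 / len(classes), max(n) / sum(n)

for subset in [("A", "B"), ("B", "C"), ("A", "C"), ("A", "B", "C")]:
    uniform, majority = baselines(subset)
    print(f"{' -> '.join(subset)}: uniform={uniform:.1%} majority={majority:.1%}")
```

Under that assumption, the majority baselines are about 54.7% (A -> B), 78.6% (B -> C), 75.3% (A -> C) and 47.6% (A -> B -> C), so the 75% on A -> C is roughly what "always predict A" would score, and 17% on three classes is below even the 33.3% of uniform guessing.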

The number of images you have might not be large enough for a multi-class classification model. The model could be overfitting on the training dataset, which is why you get such poor accuracy on your validation dataset. Moreover, the solution is not just to keep adding images to the dataset but to balance it: having little to no class imbalance might give your model some room to improve.
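Since the question already builds a `tf.data` pipeline, one way to balance it is to oversample inside the pipeline with `tf.data.Dataset.sample_from_datasets`. A minimal sketch, assuming you can split your data into one dataset per class (for example with `Dataset.filter` on the label); `balance_by_oversampling` is a hypothetical helper name.

```python
import tensorflow as tf

def balance_by_oversampling(class_datasets, seed=None):
    """Mix one dataset per class so every class is drawn with equal probability.

    class_datasets: list of tf.data.Dataset objects, one per class.
    """
    n = len(class_datasets)
    # repeat() makes every stream infinite, so the minority class is re-used
    # (oversampled) instead of running out before the others
    repeated = [ds.repeat() for ds in class_datasets]
    return tf.data.Dataset.sample_from_datasets(
        repeated, weights=[1.0 / n] * n, seed=seed)
```

The result is infinite, so it needs `.batch(...)` and an explicit `steps_per_epoch` in `model.fit`. With a balanced stream you would normally drop `class_weight`, otherwise the imbalance is corrected twice. Because minority images (class C here) repeat often, the augmentation layers matter more.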

The model architecture is quite small and not suitable for a multi-class classifier. The more data you pour into your model, the larger the architecture it needs to capture more complex features from the dataset, especially in a multi-class classification model.
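As an illustration of "larger" meaning deeper rather than wider, here is a sketch of a from-scratch CNN with four convolutional blocks that replaces Flatten with global average pooling. The block widths, the number of blocks and the dropout rate are illustrative and untuned, and `build_deeper_cnn` is a hypothetical name; it assumes the same 512x512x1 inputs and 3 one-hot classes as the question.

```python
import tensorflow as tf
from tensorflow.keras import layers

def build_deeper_cnn(num_classes=3, input_size=512, resized=256):
    """Four conv blocks + global average pooling (illustrative, untuned)."""
    inputs = tf.keras.Input(shape=(input_size, input_size, 1))
    x = layers.Resizing(resized, resized)(inputs)
    x = layers.Rescaling(1.0 / 255)(x)
    # Each block: two 3x3 convolutions, then 2x2 max pooling. Stacking blocks
    # grows the receptive field so later layers can respond to larger structures.
    for filters in (16, 32, 64, 128):
        x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
        x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
        x = layers.MaxPooling2D()(x)
    # Global average pooling instead of Flatten keeps the head small
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    model = tf.keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="categorical_crossentropy",
                  metrics=["accuracy"])
    return model
```

This network has about 0.3M parameters, far fewer than the question's model, yet eight convolutional layers instead of two. With this little data, a pre-trained backbone is still the more likely route to good results.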

  1. Why does my model perform worse when training on more classes?

Answer: No, having more classes doesn't affect the performance of the model.

  2. Do I have the wrong data and the images do not have enough conclusive information?

Answer: Yes. Your data is imbalanced, so you need to balance it. Options include random undersampling, random oversampling, and synthetic minority oversampling (SMOTE), which creates new minority-class samples by interpolating between existing ones.

  3. Is my image count too low to train a decent ML model?

Answer: Yes. Consider data augmentation.

  4. Is my CNN not deep enough for multi-class classification?

Answer: You can go deeper, and accuracy will improve up to a point, after which it starts to decrease again. So consider using transfer learning with ResNet.
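Random undersampling, the first option listed in answer 2, can be sketched in plain Python on the list of files that feeds the pipeline. Assumption: the data is available as a list of `(path, label)` pairs; `random_undersample` and the sample layout are hypothetical, not part of the question's code.

```python
import random
from collections import defaultdict

def random_undersample(samples, seed=None):
    """Randomly drop samples so every class keeps as many as the smallest class.

    samples: list of (path, label) pairs (hypothetical layout).
    """
    by_label = defaultdict(list)
    for path, label in samples:
        by_label[label].append((path, label))
    n_min = min(len(group) for group in by_label.values())
    rng = random.Random(seed)
    balanced = []
    for group in by_label.values():
        # sample without replacement, so no image appears twice
        balanced.extend(rng.sample(group, n_min))
    rng.shuffle(balanced)
    return balanced
```

With the counts in the question (1368 / 1651 / 449), this keeps 449 images per class, 1347 in total, and discards about 61% of the data. That trade-off is why oversampling or class weights are often preferred on small datasets.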
