Mean Squared Error for image classification model in TensorFlow

I am trying to train an image classification model to predict a numeric characteristic from an image. I am sure that the SparseCategoricalCrossentropy loss function does not work for me, because during training I need to penalize large differences more than small ones. Ideally, I would like to use the Mean Squared Error loss function.
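A small worked comparison makes the motivation concrete. The sketch below uses hand-written stand-ins for the two losses (not the Keras implementations) and a hypothetical true label of 2 out of 13 classes. Cross-entropy only looks at the probability assigned to the true class, so a confident prediction of 3 and a confident prediction of 12 cost exactly the same. Squared error on the number itself charges 1 for the near miss and 100 for the far one.

```python
import math

NUM_CLASSES = 13
TRUE_LABEL = 2  # hypothetical example label


def sparse_categorical_crossentropy(true_class, probs):
    # Only the probability assigned to the true class matters.
    return -math.log(probs[true_class])


def squared_error(true_value, predicted_value):
    # Grows quadratically with the numeric distance.
    return (true_value - predicted_value) ** 2


def confident_prediction(peak, p=0.9):
    # Hypothetical softmax output: 0.9 on one class, the rest spread evenly.
    rest = (1.0 - p) / (NUM_CLASSES - 1)
    return [p if i == peak else rest for i in range(NUM_CLASSES)]


ce_near = sparse_categorical_crossentropy(TRUE_LABEL, confident_prediction(3))
ce_far = sparse_categorical_crossentropy(TRUE_LABEL, confident_prediction(12))
se_near = squared_error(TRUE_LABEL, 3)
se_far = squared_error(TRUE_LABEL, 12)

print(ce_near == ce_far)  # True: cross-entropy ignores how far off the guess is
print(se_near, se_far)    # 1 100
```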

I used the TensorFlow tutorial to prepare the model: https://www.tensorflow.org/tutorials/images/classification

My class names are numbers. I tried the following options:

  • ['00', '01', '02', '03', '04', '05', '06', '07', '08', '09', '10', '11', '12']
  • ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12']
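One detail worth checking with the second naming option: the tutorial loads images with tf.keras.utils.image_dataset_from_directory, which assigns integer labels to the subfolders in alphanumeric order unless its class_names argument fixes the order. For classification the order is irrelevant, but for a numeric loss the label index has to equal the number. The sketch below only mimics that ordering with Python's sorted(); it is an illustration, not the Keras code itself.

```python
names_plain = [str(i) for i in range(13)]       # '0', '1', ..., '12'
names_padded = [f"{i:02d}" for i in range(13)]  # '00', '01', ..., '12'


def inferred_labels(folder_names):
    # Assumed to match Keras' default: sort the subfolder names as strings.
    return {name: index for index, name in enumerate(sorted(folder_names))}


plain = inferred_labels(names_plain)
padded = inferred_labels(names_padded)

print(plain["2"], plain["10"])    # 5 2  -> folder '2' gets label 5
print(padded["02"], padded["10"]) # 2 10 -> label equals the number
```

With zero-padded names (or an explicit class_names list in numeric order), the inferred label index is the number the folder represents.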

The only change I made compared with the tutorial (apart from the dataset) was replacing the SparseCategoricalCrossentropy loss function with 'mean_squared_error'.

But the loss function clearly does not work for me. It returns values that get smaller during training, but accuracy never exceeds 5%, and it even goes down as the loss decreases. The results also do not make sense. The data is fine: I can easily reach 95% accuracy with the SparseCategoricalCrossentropy loss function. What am I missing?
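One plausible explanation, assuming the single integer label ends up broadcast against all 13 outputs of the tutorial's final Dense(num_classes) layer: MSE is then minimized by pushing every output toward the label value. At that optimum the outputs are tied, so argmax (which the accuracy metric relies on) no longer points at the correct class. The NumPy sketch below reproduces that broadcasting by hand; it does not call Keras.

```python
import numpy as np

# Integer class labels treated as plain numbers, shape (batch, 1).
labels = np.array([[2.0], [7.0]])


def broadcast_mse(y_true, y_pred):
    # (batch, 1) - (batch, 13) broadcasts to (batch, 13); average over the class axis.
    return np.mean((y_true - y_pred) ** 2, axis=-1)


# The loss reaches its minimum when all 13 outputs equal the label value...
best_outputs = np.repeat(labels, 13, axis=1)
print(broadcast_mse(labels, best_outputs))  # [0. 0.]

# ...but then the outputs are tied, so argmax always returns index 0.
print(np.argmax(best_outputs, axis=1))      # [0 0]
```

This is why a falling loss can coexist with near-chance accuracy: the loss and the metric are optimizing different things.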

UPDATE: I think what I really need is a way to set up a regression problem in TensorFlow using images labeled with numbers.

It turns out that it is quite easy to turn an image classification problem into a regression problem. Compared with the tutorial referenced in the question, I had to make the following changes:

  1. A different dataset with numbers as 'classes' (folder names).

  2. Changed the loss function to Mean Squared Error or another loss function suitable for regression.

  3. Changed the model's last layer to a single neuron instead of one per class (and without softmax):

     ...
     layers.Dense(128, activation='relu'),
     layers.Dense(1)  # changed from num_classes to 1
  4. Changed how the prediction results are interpreted:

     ...
     predictions = model.predict(img_array)
     # score = tf.nn.softmax(predictions[0])  # correct for classification, but not for regression
     score = predictions.flatten()[0]  # correct result for regression
     ...
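The four changes above can be put together in one place. The following is a minimal sketch, assuming a smaller 64x64 input (the tutorial uses 180x180), random arrays as a stand-in for the tutorial's train_ds, and a hypothetical rounded_accuracy helper that is not part of Keras. It keeps the tutorial's convolutional stack, ends in Dense(1) with no softmax, trains with MeanSquaredError, and reads the prediction as a number.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

IMG_HEIGHT, IMG_WIDTH = 64, 64  # hypothetical size; the tutorial uses 180x180
NUM_VALUES = 13                 # targets 0..12, matching the folder names


def rounded_accuracy(y_true, y_pred):
    # Hypothetical metric: share of predictions that round to the exact label.
    y_true = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred = tf.reshape(tf.round(y_pred), [-1])
    return tf.reduce_mean(tf.cast(tf.equal(y_true, y_pred), tf.float32))


def build_regression_model():
    # Tutorial-style network, but ending in a single linear unit.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
        layers.Rescaling(1.0 / 255),
        layers.Conv2D(16, 3, padding="same", activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(32, 3, padding="same", activation="relu"),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.Dense(1),  # changed from num_classes to 1, no softmax
    ])
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.MeanSquaredError(),
        metrics=["mae", rounded_accuracy],
    )
    return model


# Synthetic stand-in data: random images with random integer targets.
rng = np.random.default_rng(0)
x = rng.integers(0, 256, size=(32, IMG_HEIGHT, IMG_WIDTH, 3)).astype("float32")
y = rng.integers(0, NUM_VALUES, size=(32,)).astype("float32")

model = build_regression_model()
model.fit(x, y, epochs=1, batch_size=8, verbose=0)

predictions = model.predict(x[:1], verbose=0)  # shape (1, 1)
score = float(predictions.flatten()[0])        # the predicted number itself
predicted_label = int(np.clip(np.rint(score), 0, NUM_VALUES - 1))
print(f"predicted value {score:.2f}, nearest label {predicted_label}")
```

Mean absolute error reads directly as "off by N on average", which is usually easier to interpret than the MSE value. The plain 'accuracy' metric is not meaningful for a single regression output, which is why a rounded comparison is sketched instead.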
