
Use different metrics in tf.keras.metrics for a multi-classification model

I am using the TensorFlow Federated framework for a multi-classification problem. I am following the tutorials, and most of them use tf.keras.metrics.SparseCategoricalAccuracy to measure the model's accuracy. I wanted to explore other measures such as AUC, recall, F1, and precision, but I get an error. The code and the error message are below.

import tensorflow as tf
import tensorflow_federated as tff

def create_keras_model():
  initializer = tf.keras.initializers.Zeros()
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(8,)),
      tf.keras.layers.Dense(64),
      tf.keras.layers.Dense(4, kernel_initializer=initializer),
      tf.keras.layers.Softmax(),
  ])

def model_fn():
  keras_model = create_keras_model()
  # train_data: list of client tf.data.Datasets with integer labels (defined elsewhere)
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy(),
               tf.keras.metrics.Recall()]
      )

The error

ValueError: Shapes (None, 4) and (None,) are incompatible

Is this happening because it is a multi-classification problem, so these metrics cannot be used with it? If so, which other metrics can I use to evaluate my multi-classification model?

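tf.keras.metrics.SparseCategoricalAccuracy() expects sparse (integer) class labels, while tf.keras.metrics.Recall() expects categorical (one-hot) labels. Recall compares labels and predictions element-wise, so y_true must have the same shape as the model output, (None, 4). Integer labels of shape (None,) cause exactly the shape error above.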

To use metrics without 'Sparse' in their name (for example Recall, Precision, AUC, or CategoricalAccuracy), you have to feed one-hot encoded labels.
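A small sketch of the idea, using made-up labels and softmax outputs for three examples (the numbers are illustrative only): the integer labels are one-hot encoded with tf.one_hot so that Recall and Precision see the same (batch, 4) shape as the predictions, and F1 is then computed from the two results as their harmonic mean.

```python
import tensorflow as tf

num_classes = 4

# Integer (sparse) labels, shape (3,): the format SparseCategoricalAccuracy expects
y_true_int = tf.constant([0, 2, 3])
# Made-up softmax outputs for 3 examples, shape (3, 4)
y_pred = tf.constant([[0.90, 0.05, 0.03, 0.02],
                      [0.10, 0.20, 0.60, 0.10],
                      [0.30, 0.40, 0.20, 0.10]])

# Recall/Precision compare labels and predictions element-wise, so the labels
# must match y_pred's shape: one-hot encode them to shape (3, 4)
y_true = tf.one_hot(y_true_int, depth=num_classes)

recall = tf.keras.metrics.Recall()        # default threshold 0.5
precision = tf.keras.metrics.Precision()  # default threshold 0.5
recall.update_state(y_true, y_pred)
precision.update_state(y_true, y_pred)

r = float(recall.result())
p = float(precision.result())
# F1 is the harmonic mean of precision and recall
f1 = 2 * p * r / (p + r)
print(f"recall={r:.4f} precision={p:.4f} f1={f1:.4f}")
# prints: recall=0.6667 precision=1.0000 f1=0.8000
```

Here Recall and Precision treat every (example, class) cell as a separate binary decision with a 0.5 threshold, i.e. a micro-average over the four classes. Both metrics also accept class_id to score a single class and top_k to consider only the k highest-scoring classes per example.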

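Update:

import tensorflow as tf

num_class = 4

def get_img_and_onehot_class(img_path, label):
    # `class` is a reserved word in Python, so the label argument is named `label`
    img = tf.io.read_file(img_path)
    img = tf.io.decode_jpeg(img, channels=3)
    # Other preprocessing of the image goes here.
    return img, tf.one_hot(label, num_class)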
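If the labels are already integers inside a tf.data pipeline (as with the client datasets in the question), the same conversion can be done with Dataset.map. A minimal sketch on hypothetical toy data (random features, hand-written labels):

```python
import tensorflow as tf

num_class = 4

# Hypothetical toy data: 8 features per example (matching Input(shape=(8,)) in the
# question) and integer class labels in [0, num_class)
features = tf.random.uniform((6, 8))
labels = tf.constant([0, 1, 2, 3, 1, 0])
ds = tf.data.Dataset.from_tensor_slices((features, labels)).batch(3)

# One-hot encode the labels: shape (batch,) -> (batch, num_class).
# tf.one_hot appends the new axis last, so this works before or after batching.
ds_onehot = ds.map(lambda x, y: (x, tf.one_hot(y, num_class)))

print(ds_onehot.element_spec)
for x, y in ds_onehot:
    print(x.shape, y.shape)
```

Applying such a map to every client dataset before reading train_data[0].element_spec should give a label spec of shape (None, 4), which is what CategoricalCrossentropy and the non-sparse metrics expect.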

Once you have one-hot labels:

import tensorflow as tf

# Pass an instance of the loss (note the parentheses)
loss = tf.keras.losses.CategoricalCrossentropy()
METRICS = [tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
           tf.keras.metrics.Precision(name='precision'),
           tf.keras.metrics.Recall(name='recall')]

# `model` is your Keras model (not defined in this snippet)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=loss,
    metrics=METRICS)
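If you would rather keep integer labels (for example, to leave the federated datasets and SparseCategoricalCrossentropy untouched), one alternative is to do the one-hot encoding inside the metric. The sketch below uses a hypothetical subclass, SparseRecall, which is not part of tf.keras, and feeds it made-up numbers:

```python
import tensorflow as tf


class SparseRecall(tf.keras.metrics.Recall):
    """Hypothetical helper (not part of tf.keras): Recall that accepts
    integer class labels and one-hot encodes them before updating."""

    def __init__(self, num_classes, name="sparse_recall", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Labels of shape (batch,) or (batch, 1) -> one-hot of shape (batch, num_classes)
        y_true = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        y_true = tf.one_hot(y_true, self.num_classes)
        return super().update_state(y_true, y_pred, sample_weight)

    def get_config(self):
        # Include num_classes so the metric can be rebuilt with from_config
        config = super().get_config()
        config["num_classes"] = self.num_classes
        return config


# Made-up integer labels and softmax outputs
m = SparseRecall(num_classes=4)
m.update_state(tf.constant([0, 2, 3]),
               tf.constant([[0.90, 0.05, 0.03, 0.02],
                            [0.10, 0.20, 0.60, 0.10],
                            [0.30, 0.40, 0.20, 0.10]]))
print(float(m.result()))
```

In the question's model_fn it would be passed as metrics=[tf.keras.metrics.SparseCategoricalAccuracy(), SparseRecall(num_classes=4)]. get_config is overridden because a framework may rebuild metrics from their config; this has not been checked against a specific TensorFlow Federated release. The same pattern should carry over to tf.keras.metrics.Precision by changing the base class.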
