简体   繁体   English

在 tf.keras.metrics 中使用不同的指标进行多分类模型

[英]Use different metrics in tf.keras.metrics for mutli-classification model

I am using the TensorFlow federated framework for a multiclassification problem.我正在使用 TensorFlow 联合框架来解决多分类问题。 I am following the tutorials and most of them use the metric ( tf.keras.metrics.SparseCategoricalAccuracy ) to measure the models' accuracy.我正在关注教程,其中大多数使用指标( tf.keras.metrics.SparseCategoricalAccuracy )来衡量模型的准确性。 I wanted to explore the other measures like (AUC, recall, F1, and precision) but I am getting the errors.我想探索其他指标,例如(AUC、召回、F1 和精度),但我得到了错误。 The code and the error message are provided below.下面提供了代码和错误消息。

def create_keras_model():
  initializer = tf.keras.initializers.Zeros()
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(8,)),
      tf.keras.layers.Dense(64),
      tf.keras.layers.Dense(4, kernel_initializer=initializer),
      tf.keras.layers.Softmax(),
  ])
def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy(),
               tf.keras.metrics.Recall()]
      )

The error错误

ValueError: Shapes (None, 4) and (None,) are incompatible

Is it because of the muti classification problem, that we cannot use these measures with it?是因为多分类问题,我们不能使用这些措施吗? and if so, is there any other metric I may use to measure my multi-classification model.如果是这样,我是否可以使用任何其他指标来衡量我的多分类模型。

tf.keras.metrics.SparseCategoricalAccuracy() --> is for SparseCategorical (int) class. tf.keras.metrics.SparseCategoricalAccuracy() --> 用于 SparseCategorical (int) 类。 tf.keras.metrics.Recall() --> is for categorical (one-hot) class. tf.keras.metrics.Recall() --> 用于分类(one-hot)类。

You have to use a one-hot class if you want to use any metric naming without the 'Sparse'.如果要使用没有“稀疏”的任何度量命名,则必须使用 one-hot 类。

update:更新:

num_class=4
def get_img_and_onehot_class(img_path, class):
    img = tf.io.read_file(img_path)
    img = tf.io.decode_jpeg(img, channels=3)
    """ Other preprocessing of image."""
    return img, tf.one_hot(class, num_class)

when you got the one-hot class:当你上了一堂热课时:

loss=tf.losses.CategoricalCrossentropy
METRICS=[tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall'),]

model.compile(
        optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001),
        loss=loss,
        metrics= METRICS)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 Tensorflow:如何在多类分类中使用 tf.keras.metrics? - Tensorflow: How to use tf.keras.metrics in multiclass classification? 如何在 tf.keras.metrics 中找到误报率 - how to find false positive rate in tf.keras.metrics 在Keras使用tf.metrics? - Use tf.metrics in Keras? 使用度量“tf.keras.metrics.AUC”进行训练后,如何在 Keras 中使用“load_model”? - How to use "load_model" in Keras after trained using as metric "tf.keras.metrics.AUC"? 了解 tf.keras.metrics.Precision and Recall 进行多类分类 - Understanding tf.keras.metrics.Precision and Recall for multiclass classification 使用在 tf.keras 中实现的自定义指标加载 keras model - Loading keras model with custom metrics implemented in tf.keras 带有 TF 后端的 Keras 指标与 tensorflow 指标 - Keras metrics with TF backend vs tensorflow metrics 何时在分类器神经网络 model 中使用“准确度”字符串或 tf.keras.metrics.Accuracy() - When to use "accuracy" string or tf.keras.metrics.Accuracy() in a classifier neural network model 我可以在多类分类问题中使用 tf.metrics.BinaryAccuracy 吗? - Can I use tf.metrics.BinaryAccuracy in a multiclass classification problem? 可以传递给 tf.keras.model.compile 的指标列表 - List of metrics that can be passed to tf.keras.model.compile
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM