[英]Use different metrics in tf.keras.metrics for mutli-classification model
I am using the TensorFlow federated framework for a multiclassification problem.我正在使用 TensorFlow 联合框架来解决多分类问题。 I am following the tutorials and most of them use the metric (
tf.keras.metrics.SparseCategoricalAccuracy
) to measure the models' accuracy.我正在关注教程,其中大多数使用指标(
tf.keras.metrics.SparseCategoricalAccuracy
)来衡量模型的准确性。 I wanted to explore the other measures like (AUC, recall, F1, and precision) but I am getting the errors.我想探索其他指标,例如(AUC、召回、F1 和精度),但我得到了错误。 The code and the error message are provided below.
下面提供了代码和错误消息。
def create_keras_model():
initializer = tf.keras.initializers.Zeros()
return tf.keras.models.Sequential([
tf.keras.layers.Input(shape=(8,)),
tf.keras.layers.Dense(64),
tf.keras.layers.Dense(4, kernel_initializer=initializer),
tf.keras.layers.Softmax(),
])
def model_fn():
keras_model = create_keras_model()
return tff.learning.from_keras_model(
keras_model,
input_spec=train_data[0].element_spec,
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy(),
tf.keras.metrics.Recall()]
)
The error错误
ValueError: Shapes (None, 4) and (None,) are incompatible
Is it because of the muti classification problem, that we cannot use these measures with it?是因为多分类问题,我们不能使用这些措施吗? and if so, is there any other metric I may use to measure my multi-classification model.
如果是这样,我是否可以使用任何其他指标来衡量我的多分类模型。
tf.keras.metrics.SparseCategoricalAccuracy() --> is for SparseCategorical (int) class. tf.keras.metrics.SparseCategoricalAccuracy() --> 用于 SparseCategorical (int) 类。 tf.keras.metrics.Recall() --> is for categorical (one-hot) class.
tf.keras.metrics.Recall() --> 用于分类(one-hot)类。
You have to use a one-hot class if you want to use any metric naming without the 'Sparse'.如果要使用没有“稀疏”的任何度量命名,则必须使用 one-hot 类。
update:更新:
num_class=4
def get_img_and_onehot_class(img_path, class):
img = tf.io.read_file(img_path)
img = tf.io.decode_jpeg(img, channels=3)
""" Other preprocessing of image."""
return img, tf.one_hot(class, num_class)
when you got the one-hot class:当你上了一堂热课时:
loss=tf.losses.CategoricalCrossentropy
METRICS=[tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
tf.keras.metrics.Precision(name='precision'),
tf.keras.metrics.Recall(name='recall'),]
model.compile(
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001),
loss=loss,
metrics= METRICS)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.