[英]Tensorflow Macro F1 Score for multiclass and also for binary classification
我正在尝试训练 2 个 1D Conv 神经网络 - 一个用于多类分类问题,第二个用于二元分类问题。 我的指标之一必须是这两个问题的 Macro F1 分数。 但是我在使用来自 tensorflow 插件的tfa.metrics.F1Score
时遇到问题。
我有 3 个类编码为 0、1、2。
网络的最后一层和编译方法如下所示( int_sequeces_input
是输入层):
preds = layers.Dense(3, activation="softmax")(x)
model = keras.Model(int_sequences_input, preds)
f1_macro = F1Score(num_classes=3, average='macro')
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy',f1_macro])
但是,当我运行model.fit()
时,出现以下错误:
ValueError: Dimension 0 in both shapes must be equal, but are 3 and 1. Shapes are [3] and [1]. for '{{node AssignAddVariableOp_7}} = AssignAddVariableOp[dtype=DT_FLOAT](AssignAddVariableOp_7/resource, Sum_6)' with input shapes: [], [1].
X_train - (23658, 150)
y_train - (23658,)
我有 2 个类编码为 0,1
网络的最后一层和编译方法如下所示( int_sequeces_input
是输入层):
preds = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(int_sequences_input, preds)
print(model.summary())
f1_macro = F1Score(num_classes=2, average='macro')
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy',f1_macro])
同样,当我运行model.fit()
时出现错误:
ValueError: Dimension 0 in both shapes must be equal, but are 2 and 1. Shapes are [2] and [1]. for '{{node AssignAddVariableOp_4}} = AssignAddVariableOp[dtype=DT_FLOAT](AssignAddVariableOp_4/resource, Sum_3)' with input shapes: [], [1].
X_train - (15770, 150)
y_train - (15770,)
所以我的问题是:如何使用宏观 F1 分数评估我的两个模型? 如何修复我的实现以使其与tfa.metrics.F1Score
一起使用? 或者有没有其他方法可以在不使用tfa.metrics.F1Score
的情况下计算宏观 F1 分数? 谢谢。
从其文档页面查看用法示例。
metric = tfa.metrics.F1Score(num_classes=3, threshold=0.5)
y_true = np.array([[1, 1, 1],
[1, 0, 0],
[1, 1, 0]], np.int32)
y_pred = np.array([[0.2, 0.6, 0.7],
[0.2, 0.6, 0.6],
[0.6, 0.8, 0.0]], np.float32)
metric.update_state(y_true, y_pred)
您可以看到它希望标签采用 one-hot 格式。
但是考虑到您上面提到的形状:
shapes of data:
X_train - (23658, 150)
y_train - (23658,)
看起来您的标签采用索引格式。 尝试使用tf.one_hot(y_train, num_classes)
将它们转换为一个热点。 您还需要将损失更改为loss='categorical_crossentropy'
。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.