![](/img/trans.png)
[英]Understanding tf.keras.metrics.Precision and Recall for multiclass classification
[英]Does following code give recall for multiclass classification in Keras?
以下代碼是否為 Keras 中的多類分類提供召回? 即使我在 model.compile 中調用召回 function 時沒有傳遞 y_true 和 y_pred,它也向我展示了召回的結果。
def recall(y_true, y_pred):
y_true = K.ones_like(y_true)
true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
all_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
recall = true_positives / (all_positives + K.epsilon())
return recall
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=[recall])
是的,它有效,因為在model.fit()
內部多次調用召回,指定這些值。
它的工作方式與此類似(更復雜和優化):
accuracy = tf.keras.metrics.CategoricalAccuracy()
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()
# Iterate over the batches of a dataset.
for step, (x, y) in enumerate(dataset):
with tf.GradientTape() as tape:
logits = model(x)
# Compute the loss value for this batch.
loss_value = loss_fn(y, logits)
# Update the state of the `accuracy` metric.
accuracy.update_state(y, logits)
# Update the weights of the model to minimize the loss value.
gradients = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(gradients, model.trainable_weights))
# Logging the current accuracy value so far.
if step % 100 == 0:
print('Step:', step)
print('Total running accuracy so far: %.3f' % accuracy.result())
這稱為漸變磁帶,可用於執行自定義的訓練循環。 基本上,它公開了在 model 的可訓練張量上計算的梯度。 它允許您手動更新 model 的權重,因此對於個性化非常有用。 所有這些東西也在model.fit()
內部自動完成。 你不需要這個,它只是為了解釋事情是如何工作的。
如您所見,在每批數據集中計算預測,即logits
。 logits
和 ground truth,即正確的y
值,作為 arguments 提供給accuracy.update_state
,就像在model.fit()
中沒有看到它一樣。 即使順序相同, y_true
和y
都是基本事實,而y_pred
和logits
是預測。
我希望這讓事情變得更清楚了。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.