簡體   English   中英

以下代碼是否為 Keras 中的多類分類提供召回?

[英]Does following code give recall for multiclass classification in Keras?

以下代碼是否為 Keras 中的多類分類提供召回? 即使我在 model.compile 中調用召回 function 時沒有傳遞 y_true 和 y_pred,它也向我展示了召回的結果。

def recall(y_true, y_pred):
    y_true = K.ones_like(y_true) 
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    all_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    
    recall = true_positives / (all_positives + K.epsilon())
    return recall

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=[recall])

是的,它有效,因為在model.fit()內部多次調用召回,指定這些值。

它的工作方式與類似(更復雜和優化):

accuracy = tf.keras.metrics.CategoricalAccuracy()
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for step, (x, y) in enumerate(dataset):
    with tf.GradientTape() as tape:
        logits = model(x)
        # Compute the loss value for this batch.
        loss_value = loss_fn(y, logits)

    # Update the state of the `accuracy` metric.
    accuracy.update_state(y, logits)

    # Update the weights of the model to minimize the loss value.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

    # Logging the current accuracy value so far.
    if step % 100 == 0:
        print('Step:', step)        
        print('Total running accuracy so far: %.3f' % accuracy.result())

這稱為漸變磁帶,可用於執行自定義的訓練循環。 基本上,它公開了在 model 的可訓練張量上計算的梯度。 它允許您手動更新 model 的權重,因此對於個性化非常有用。 所有這些東西也在model.fit()內部自動完成。 你不需要這個,它只是為了解釋事情是如何工作的。

如您所見,在每批數據集中計算預測,即logits logits和 ground truth,即正確的y值,作為 arguments 提供給accuracy.update_state ,就像在model.fit()中沒有看到它一樣。 即使順序相同, y_truey都是基本事實,而y_predlogits是預測。

我希望這讓事情變得更清楚了。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM