簡體   English   中英

Keras中的自定義損失輸出

[英]Output of custom loss in Keras

我知道在Keras中處理自定義損失函數有很多問題,但是即使經過3個小時的搜尋,我也無法回答。

這是我的問題的非常簡化的示例。 我意識到這個例子是沒有意義的,但是為了簡單起見,我提供了它,顯然我需要實現更復雜的東西。

from keras.backend import binary_crossentropy
from keras.backend import mean
def custom_loss(y_true, y_pred):

    zeros = tf.zeros_like(y_true)
    index_of_zeros = tf.where(tf.equal(zeros, y_true))
    ones = tf.ones_like(y_true)
    index_of_ones = tf.where(tf.equal(ones, y_true))

    zero = tf.gather(y_pred, index_of_zeros)
    one = tf.gather(y_pred, index_of_ones)

    loss_0 = binary_crossentropy(tf.zeros_like(zero), zero)
    loss_1 = binary_crossentropy(tf.ones_like(one), one)

    return mean(tf.concat([loss_0, loss_1], axis=0))

我不明白為什么在兩類數據集上使用上述損失函數訓練網絡不會產生與使用內置binary-crossentropy損失函數進行訓練的結果相同的結果。 謝謝!

編輯 :我編輯了代碼片段,以包括以下每個注釋的平均值。 我仍然得到相同的行為。

我終於弄明白了。 當形狀為“未知”時, tf.where函數的行為會非常不同。 要修復上面的代碼段,只需在函數聲明后立即插入以下幾行:

y_pred = tf.reshape(y_pred, [-1])
y_true = tf.reshape(y_true, [-1])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM