簡體   English   中英

Keras 多輸出,自定義損失函數

[英]Keras multiple outputs, customed loss function

我試圖構建一個具有兩個輸入和兩個輸出的模型。 模型的結構如下。 我想構建一個包含兩部分的客戶損失函數:'d_flat' 和 't_flat' 之間的差異,以及層 'perdict' 的分類交叉熵損失。 模型是這樣的:

initial_input_domain=tf.keras.Input(shape=(36,36,3))
initial_input_target=tf.keras.Input(shape=(36,36,3))

vgg_base=tf.keras.applications.VGG19(include_top=False,#weights='imagenet',
                                     input_shape=(36,36,3))

domain1=vgg_base(initial_input_domain)
target1=vgg_base(initial_input_target)

d_flat = tf.keras.layers.Flatten(name='d_flat')(domain1)
predictions=tf.keras.layers.Dense(num_classes,name='predict', activation='sigmoid')(d_flat)

t_flat = tf.keras.layers.Flatten(name='t_flat')(target1)
predictions_t=tf.keras.layers.Dense(num_classes,name='predict_t', activation='sigmoid')(t_flat)

fin_model=tf.keras.Model(inputs=[initial_input_domain,initial_input_target], outputs=[predictions, predictions_t])

在此處輸入圖片說明

我寫的損失函數是這樣的:

def Total_loss(d_flat, t_flat):

    def loss_function(y_true, y_pred):

        Dist_LOSS = 'something does not matter' # the difference of two layers
        loss = K.categorical_crossentropy(y_true,y_pred) + Dist_LOSS
        return loss

    return loss_function

所以我的問題是這個函數中的 y_pred 和 y_true 是什么? 我只希望這個函數計算'預測'的分類交叉熵損失,這是左邊的部分。 我該怎么做才能使 keras 不計算正確部分的分類交叉熵損失? 看起來 y_pred 和 y_true 是左右分支的組合。 (我用在右邊的標簽是正確的標簽,我用在右邊的都是0,沒有任何意義)

Keras 生成這些輸出,

Epoch 1/100
6912/6912 [==============================] - 24s 3ms/sample - loss: 0.0315 - predict_loss: 0.0270 - predict_t_loss: 0.0045 - predict_categorical_accuracy: 0.9931 - predict_t_categorical_accuracy: 0.6413

看起來損失 = predict_loss + predict_t_loss。 它應該是任何 predict_t_loss。 任何建議表示贊賞。 謝謝!

自定義損失函數只能與(y_true, y_pred) 如果您想使用在最后一層之前定義的其他變量,例如d_flat, t_flat或僅輸出的一部分,則必須使用model.add_loss 正如您在 API 中所見,您可以在您自己的自定義層(為您提供更具體的控制)或在模型本身上定義它。 在您的情況下,您可以執行以下操作:

model.add_loss(d_flat - f_flat + K.categorical_crossentropy(predictions,y_pred_for_left_output)

其中y_pred_for_left_output是此輸出節點的標簽張量。 通過這種方式,您將損失定義為平坦層與僅左側輸出節點的 CE 之間的差異。 您可以根據您的特定要求進行調整,但這將是正確的方法。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM