簡體   English   中英

Tensorflow 一種用於多輸出模型的自定義指標

[英]Tensorflow one custom metric for multioutput models

我在文檔中找不到信息,所以我在這里問。

我有一個具有 3 個不同輸出的多輸出模型:

model = tf.keras.Model(inputs=[input], outputs=[output1, output2, output3])

用於驗證的預測標簽由這 3 個輸出構成,僅形成一個,這是一個后處理步驟。 用於訓練的數據集是這 3 個中間輸出的數據集,為了驗證,我評估了標簽數據集而不是 3 種中間數據。

我想使用自定義指標來評估我的模型,該指標處理后處理並與基本事實進行比較。

我的問題是,在自定義指標的代碼中, y_pred會是模型的 3 個輸出的列表嗎?

class MyCustomMetric(tf.keras.metrics.Metric):

  def __init__(self, name='my_custom_metric', **kwargs):
    super(MyCustomMetric, self).__init__(name=name, **kwargs)

  def update_state(self, y_true, y_pred, sample_weight=None):
    # ? is y_pred a list [batch_output_1, batch_output_2, batch_output_3] ? 

  def result(self):
    pass 

# one single metric handling the 3 outputs?
model.compile(optimizer=tf.compat.v1.train.RMSPropOptimizer(0.01),
              loss=tf.keras.losses.categorical_crossentropy,
              metrics=[MyCustomMetric()])

根據您給定的模型定義,這是一個標准的多輸出模型。

model = tf.keras.Model(inputs=[input], outputs=[output_1, output_2, output_3])

通常,所有(自定義)指標以及(自定義)損失將分別在每個輸出上調用(如 y_pred)! 在損失/度量函數中,您只會看到一個輸出以及一個相應的目標張量。 通過傳遞損失函數列表(長度 == 模型的輸出數量),您可以指定哪個損失將用於哪個輸出:

model.compile(optimizer=Adam(), loss=[loss_for_output_1, loss_for_output_2, loss_for_output_3], loss_weights=[1, 4, 8])

總損失(即要最小化的目標函數)將是所有損失乘以給定損失權重的加法組合。

指標幾乎相同! 在這里,您可以傳遞(至於損失)一個指標列表(長度 == 輸出數量),並告訴 Keras 將哪個指標用於您的哪個模型輸出。

model.compile(optimizer=Adam(), loss='mse', metrics=[metrics_for_output_1, metrics_for_output2, metrics_for_output3])

這里的metrics_for_output_X 可以是一個函數,也可以是一個函數列表,它們都被調用,其中一個對應的output_X 作為y_pred。

這在 Keras 中的多輸出模型文檔中有詳細解釋。 他們還展示了使用字典(將損失/度量函數映射到特定輸出)而不是列表的示例。 https://keras.io/getting-started/functional-api-guide/#multi-input-and-multi-output-models

更多信息:

如果我理解正確,您想使用損失函數將三個模型輸出與三個真實值進行比較來訓練您的模型,並希望通過比較來自三個模型輸出的派生值和單個真實值來進行某種性能評估. 通常,模型會根據評估它的相同目標進行訓練,否則在評估模型時可能會得到更差的結果!

無論如何......為了在單個標簽上評估您的模型,我建議您:

1.(清潔液)

重寫您的模型並合並后處理步驟。 添加所有必要的操作(作為圖層)並將它們映射到輔助輸出。 為了訓練您的模型,您可以將輔助輸出的 loss_weight 設置為零。 合並您的數據集,以便您可以為模型提供模型輸入、中間目標輸出以及標簽。 如上所述,您現在可以定義一個指標,將輔助模型輸出與給定的目標標簽進行比較。

2.

或者,您可以通過在 model.predict(input) 的三個輸出上計算您的后處理步驟來訓練您的模型並導出度量,例如在自定義回調中。 如果您想在張量板中跟蹤這些值,則需要編寫自定義摘要! 這就是我不推薦此解決方案的原因。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM