
TensorFlow: one custom metric for multi-output models

I can't find the info in the documentation so I am asking here.

I have a multioutput model with 3 different outputs:

model = tf.keras.Model(inputs=[input], outputs=[output1, output2, output3])

The predicted labels for validation are constructed from these 3 outputs in a post-processing step, to form a single label. The training dataset contains those 3 intermediary outputs; for validation, I evaluate on a dataset of final labels instead of the 3 kinds of intermediary data.

I would like to evaluate my model with a custom metric that handles the post-processing and the comparison with the ground truth.

My question is: in the code of the custom metric, will y_pred be a list of the 3 outputs of the model?

class MyCustomMetric(tf.keras.metrics.Metric):

  def __init__(self, name='my_custom_metric', **kwargs):
    super(MyCustomMetric, self).__init__(name=name, **kwargs)

  def update_state(self, y_true, y_pred, sample_weight=None):
    # ? is y_pred a list [batch_output_1, batch_output_2, batch_output_3] ?
    pass

  def result(self):
    pass

# one single metric handling the 3 outputs?
model.compile(optimizer=tf.compat.v1.train.RMSPropOptimizer(0.01),
              loss=tf.keras.losses.categorical_crossentropy,
              metrics=[MyCustomMetric()])

With your given model definition, this is a standard multi-output Model.

model = tf.keras.Model(inputs=[input], outputs=[output_1, output_2, output_3])

In general, all (custom) metrics as well as (custom) losses will be called on every output separately (as y_pred)! Within the loss/metric function you will only see one output together with its corresponding target tensor. By passing a list of loss functions (length == number of model outputs) you can specify which loss will be used for which output:

model.compile(optimizer=Adam(), loss=[loss_for_output_1, loss_for_output_2, loss_for_output_3], loss_weights=[1, 4, 8])

The total loss (which is the objective function to minimize) will be the additive combination of all losses, each multiplied with its given loss weight.
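To make the weighting concrete, here is a minimal plain-Python sketch of that additive combination (the function name `total_loss` and the sample loss values are illustrative, not Keras internals):

```python
def total_loss(per_output_losses, loss_weights):
    """Additive combination of per-output losses, each scaled by its weight."""
    return sum(w * l for l, w in zip(per_output_losses, loss_weights))

# With the weights [1, 4, 8] from the compile() call above and, say,
# per-output loss values of 0.5, 0.25 and 0.1:
# 0.5*1 + 0.25*4 + 0.1*8 = 2.3
print(total_loss([0.5, 0.25, 0.1], [1, 4, 8]))
```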

It is almost the same for the metrics! Here you can pass (as for the loss) a list (length == number of outputs) of metrics and tell Keras which metric to use for which of your model outputs.

model.compile(optimizer=Adam(), loss='mse', metrics=[metrics_for_output_1, metrics_for_output_2, metrics_for_output_3])

Here metrics_for_output_X can be either a single function or a list of functions, all of which will be called with the one corresponding output_X as y_pred.
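The dispatch logic can be illustrated without TensorFlow at all; this is a hypothetical sketch (the helper `dispatch_metrics` is mine, not a Keras function) of how each entry of the spec only ever sees its own output:

```python
def dispatch_metrics(metric_spec, y_trues, y_preds):
    """Call each output's metric(s) only on that output's (y_true, y_pred)."""
    results = []
    for metrics, y_true, y_pred in zip(metric_spec, y_trues, y_preds):
        if not isinstance(metrics, (list, tuple)):
            metrics = [metrics]  # a single function is treated as a 1-element list
        results.append([m(y_true, y_pred) for m in metrics])
    return results

mae = lambda t, p: sum(abs(a - b) for a, b in zip(t, p)) / len(t)
mse = lambda t, p: sum((a - b) ** 2 for a, b in zip(t, p)) / len(t)

# output 1 tracked with two metrics, outputs 2 and 3 with one each:
spec = [[mae, mse], mae, mae]
print(dispatch_metrics(spec, [[1, 2], [3], [4]], [[1, 3], [3], [6]]))
# [[0.5, 0.5], [0.0], [2.0]]
```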

This is explained in detail in the documentation of Multi-Output Models in Keras. They also show examples for using dictionaries (to map loss/metric functions to a specific output) instead of lists. https://keras.io/getting-started/functional-api-guide/#multi-input-and-multi-output-models
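A hedged sketch of that dictionary form: the layer names "out_1", "out_2", "out_3" are assumptions I introduce here (set via the `name=` argument of the output layers), not names from your model:

```python
import tensorflow as tf

inp = tf.keras.Input(shape=(4,))
out_1 = tf.keras.layers.Dense(3, activation="softmax", name="out_1")(inp)
out_2 = tf.keras.layers.Dense(1, name="out_2")(inp)
out_3 = tf.keras.layers.Dense(1, name="out_3")(inp)
model = tf.keras.Model(inputs=inp, outputs=[out_1, out_2, out_3])

# dictionaries map losses/weights/metrics to outputs by layer name
model.compile(
    optimizer="adam",
    loss={"out_1": "categorical_crossentropy", "out_2": "mse", "out_3": "mse"},
    loss_weights={"out_1": 1.0, "out_2": 4.0, "out_3": 8.0},
    metrics={"out_1": ["accuracy"], "out_3": ["mae"]},
)
```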

Further information:

If I understand you correctly, you want to train your model with a loss that compares the three model outputs against three ground-truth values, but evaluate it by comparing a value derived from the three outputs against a single ground-truth label. Note that a model is usually trained on the same objective it is evaluated on; otherwise you might get poorer results when evaluating your model!

Anyway, for evaluating your model on a single label I suggest you either:

1. (The clean solution)

Rewrite your model and incorporate the post-processing steps. Add all the necessary operations (as layers) and map them to an auxiliary output. For training, you can set the loss weight of the auxiliary output to zero. Merge your datasets so you can feed the model the input, the intermediate target outputs, as well as the labels. As explained above, you can then define a metric comparing the auxiliary model output with the given target labels.
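Option 1 might look like the sketch below. Since your real post-processing is unknown, a simple Add layer stands in for it (swap in your own ops/layers), and all layer names and sizes are illustrative assumptions:

```python
import tensorflow as tf

inp = tf.keras.Input(shape=(4,))
h = tf.keras.layers.Dense(8, activation="relu")(inp)
out_1 = tf.keras.layers.Dense(1, name="out_1")(h)
out_2 = tf.keras.layers.Dense(1, name="out_2")(h)
out_3 = tf.keras.layers.Dense(1, name="out_3")(h)

# post-processing expressed as layers, mapped to an auxiliary output
# (placeholder: replace Add with your actual post-processing)
post = tf.keras.layers.Add(name="post")([out_1, out_2, out_3])

model = tf.keras.Model(inputs=inp, outputs=[out_1, out_2, out_3, post])
model.compile(
    optimizer="adam",
    loss={"out_1": "mse", "out_2": "mse", "out_3": "mse", "post": "mae"},
    # zero weight: the auxiliary output does not influence training
    loss_weights={"out_1": 1.0, "out_2": 1.0, "out_3": 1.0, "post": 0.0},
    metrics={"post": ["mae"]},  # evaluate only the post-processed output
)
```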

2.

Or you train your model as-is and derive the metric, e.g. in a custom Callback, by running your post-processing steps on the three outputs of model.predict(input). This makes it necessary to write custom summaries if you want to track those values in your TensorBoard, which is why I would not recommend this solution.
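For completeness, option 2 could be sketched roughly as follows. The class name, the `post_process` function, and the `"val_post_mae"` log key are all hypothetical, and the metric here is a simple MAE standing in for your real comparison:

```python
import numpy as np
import tensorflow as tf

class PostProcessedMetric(tf.keras.callbacks.Callback):
    """After each epoch: predict the three raw outputs, derive one label via
    post-processing, and compare it against the single ground-truth labels."""

    def __init__(self, val_inputs, val_labels, post_process):
        super().__init__()
        self.val_inputs = val_inputs
        self.val_labels = val_labels
        self.post_process = post_process  # your post-processing function

    def on_epoch_end(self, epoch, logs=None):
        out_1, out_2, out_3 = self.model.predict(self.val_inputs, verbose=0)
        y_pred = self.post_process(out_1, out_2, out_3)
        score = float(np.mean(np.abs(self.val_labels - y_pred)))
        if logs is not None:
            logs["val_post_mae"] = score  # shows up in the History object
```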
