How do I specify what a model should pass as input to a custom loss function?

I'm having trouble understanding and implementing a custom loss function in my model.

I have a Keras model composed of 3 sub-models, as you can see in the model architecture:

[Figure: model architecture]

Now, I'd like to use the outputs of model and model_2 in my custom loss function. I understand that in the loss function definition I can write:

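    import tensorflow as tf

    def custom_mse(y_true, y_pred):
        # *calculate stuff*: placeholder, shown here as a plain mean squared error
        loss = tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)
        return loss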

But how do I tell the model to pass its 2 outputs as inputs to the loss function?

Maybe (and I hope so) it's super trivial, but I couldn't find anything online. If you could help me, it'd be fantastic.

Thanks in advance

Context: model and model_2 are the same pretrained model, a binary classifier that predicts the interaction between 2 image-like inputs. model_1 is a generative model that edits one of the inputs.

Therefore:

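    import tensorflow as tf
    from tensorflow.keras import Model

    complete_model = Model(inputs=[input_1, input_2], outputs=[out_model, out_model2])
    opt = tf.keras.optimizers.Adam()  # any optimizer

    complete_model.compile(loss=custom_mse,
                           # ??????  <- how do I pass both outputs to custom_mse?
                           optimizer=opt,
                           metrics=['accuracy'])  # or whichever metric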
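For reference, a minimal sketch of the architecture described above, using the Keras functional API. It is a toy under several assumptions: make_classifier and make_generator are hypothetical stand-ins for the pretrained classifier and for model_1, the image size is shrunk from 500x400x1, and it assumes the generator edits input_1. Calling one classifier instance on both branches is one way to make model and model_2 "the same model" with shared weights.

```python
import tensorflow as tf

# Toy sizes (assumption); the error message later in the post shows 500x400x1.
H, W, C = 8, 8, 1


def make_classifier():
    """Hypothetical stand-in for the pretrained binary classifier
    (model / model_2): two image-like inputs -> interaction probability."""
    a = tf.keras.Input(shape=(H, W, C))
    b = tf.keras.Input(shape=(H, W, C))
    x = tf.keras.layers.Concatenate()([a, b])
    x = tf.keras.layers.Flatten()(x)
    out = tf.keras.layers.Dense(1, activation="sigmoid")(x)
    return tf.keras.Model([a, b], out, name="classifier")


def make_generator():
    """Hypothetical stand-in for model_1: returns an edited image."""
    a = tf.keras.Input(shape=(H, W, C))
    out = tf.keras.layers.Conv2D(C, 3, padding="same")(a)
    return tf.keras.Model(a, out, name="generator")


classifier = make_classifier()
generator = make_generator()

input_1 = tf.keras.Input(shape=(H, W, C), name="input_1")
input_2 = tf.keras.Input(shape=(H, W, C), name="input_2")

# Assumption: the generator edits input_1.
edited_1 = generator(input_1)

# Calling the same classifier instance twice reuses (shares) its weights.
out_model = classifier([edited_1, input_2])   # prediction on the edited input
out_model2 = classifier([input_1, input_2])   # prediction on the unedited input

complete_model = tf.keras.Model(inputs=[input_1, input_2],
                                outputs=[out_model, out_model2],
                                name="complete_model")
```

Because the classifier is shared, complete_model has only the classifier's and generator's weights (no duplicate copy for model_2).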

The main goal is to compare the prediction made on the edited input against the one made on the unedited input. Therefore the model outputs the 2 interactions, which I need to use in the loss function.
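One reason compile() alone does not cover this: with a multi-output model, Keras applies each compiled loss to one output at a time, passing only that output's own target and prediction. The sketch assumes a toy two-output model (the names out_a and out_b are hypothetical) and shows custom_mse being applied once per output, so no single call sees both predictions.

```python
import numpy as np
import tensorflow as tf


def custom_mse(y_true, y_pred):
    # Keras calls this once per output, each time with only that
    # output's target (y_true) and prediction (y_pred).
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)


# Toy two-output model (assumption), standing in for complete_model.
inp = tf.keras.Input(shape=(4,))
out_a = tf.keras.layers.Dense(1, activation="sigmoid", name="out_a")(inp)
out_b = tf.keras.layers.Dense(1, activation="sigmoid", name="out_b")(inp)
model = tf.keras.Model(inp, [out_a, out_b])

# Total loss = custom_mse(y_a, out_a) + custom_mse(y_b, out_b);
# no single call receives out_a and out_b together.
model.compile(optimizer="adam", loss=[custom_mse, custom_mse])

x = np.random.rand(8, 4).astype("float32")
y_a = np.random.randint(0, 2, size=(8, 1)).astype("float32")
y_b = np.random.randint(0, 2, size=(8, 1)).astype("float32")
history = model.fit(x, [y_a, y_b], epochs=1, verbose=0)
```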

EDIT: Thank you, Andrey, for the solution.

Now, however, I can't get the 2 loss functions to work together: the one added with add_loss(func) and a classic binary_crossentropy passed via model.compile(loss='binary_crossentropy', ...).

Could I instead use add_loss with model_2.output and the label? If so, do you know how?
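One pattern that seems to fit, similar to the endpoint-layer example in the Keras training and evaluation guide, is to feed the label into the model as an extra Input and compute both terms inside a layer that calls add_loss. This is a sketch under assumptions: CombinedLoss, gap_weight, and the toy Dense heads are hypothetical, and it assumes model_2's output is the prediction on the unedited input (pred_original here), which receives the binary cross-entropy.

```python
import numpy as np
import tensorflow as tf


class CombinedLoss(tf.keras.layers.Layer):
    """Hypothetical helper layer: registers BCE(label, pred_original) plus a
    weighted squared gap between the two predictions via add_loss, then
    passes both predictions through unchanged."""

    def __init__(self, gap_weight=1.0, **kwargs):
        super().__init__(**kwargs)
        self.gap_weight = gap_weight
        self.bce = tf.keras.losses.BinaryCrossentropy()

    def call(self, inputs):
        label, pred_edited, pred_original = inputs
        bce = self.bce(label, pred_original)                          # scalar
        gap = tf.reduce_mean(tf.square(pred_edited - pred_original))  # scalar
        self.add_loss(bce + self.gap_weight * gap)
        return [pred_edited, pred_original]


# Toy network (assumption): two sigmoid heads on a 4-feature input.
x_in = tf.keras.Input(shape=(4,), name="features")
label_in = tf.keras.Input(shape=(1,), name="label")  # the label as a model input
pred_edited = tf.keras.layers.Dense(1, activation="sigmoid")(x_in)
pred_original = tf.keras.layers.Dense(1, activation="sigmoid")(x_in)
out_edited, out_original = CombinedLoss()([label_in, pred_edited, pred_original])

train_model = tf.keras.Model([x_in, label_in], [out_edited, out_original])
train_model.compile(optimizer="adam")  # no loss argument: add_loss provides it

features = np.random.rand(8, 4).astype("float32")
labels = np.random.randint(0, 2, size=(8, 1)).astype("float32")
history = train_model.fit([features, labels], epochs=1, verbose=0)

# For inference, a model without the label input can reuse the same layers.
infer_model = tf.keras.Model(x_in, [pred_edited, pred_original])
```

Since the labels arrive as an input, fit() is called without a separate y.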

Each one works on its own, but not together. When I run the code, I get:

ValueError: Shapes must be equal rank, but are 0 and 4 From merging shape 0 with other shapes. for '{{node AddN}} = AddN[N=2, T=DT_FLOAT](binary_crossentropy/weighted_loss/value, complete_model/generator/tf_op_layer_SquaredDifference_3/SquaredDifference_3)' with input shapes: [], [?,500,400,1].
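Reading the shapes in the error, the binary_crossentropy term from compile() is a scalar ([], rank 0) while the add_loss term is a SquaredDifference tensor of shape [?,500,400,1] (rank 4), and Keras tries to sum them with AddN. A likely fix, inferred from the message rather than confirmed, is to reduce the add_loss term to a scalar (for example with tf.reduce_mean) before registering it. The sketch uses a hypothetical SquaredDifferencePenalty layer, a one-layer stand-in generator, and a small image size.

```python
import numpy as np
import tensorflow as tf

H, W = 6, 5  # toy image size (assumption); the error shows 500x400x1


class SquaredDifferencePenalty(tf.keras.layers.Layer):
    """Hypothetical layer: adds an image-space squared-difference term via
    add_loss and returns the edited image unchanged."""

    def call(self, inputs):
        original, edited = inputs
        sq = tf.math.squared_difference(edited, original)  # (batch, H, W, 1): rank 4
        # Reduce to a scalar (rank 0) so it can be summed with the scalar
        # compile() loss; an unreduced rank-4 term is what the AddN error
        # above appears to be complaining about.
        self.add_loss(tf.reduce_mean(sq))
        return edited


img_in = tf.keras.Input(shape=(H, W, 1))
edited = tf.keras.layers.Conv2D(1, 3, padding="same")(img_in)  # stand-in generator
edited = SquaredDifferencePenalty()([img_in, edited])
pred = tf.keras.layers.Dense(1, activation="sigmoid")(tf.keras.layers.Flatten()(edited))
model = tf.keras.Model(img_in, pred)

# Standard compile() loss plus the (now scalar) add_loss term.
model.compile(optimizer="adam", loss="binary_crossentropy")

images = np.random.rand(4, H, W, 1).astype("float32")
labels = np.random.randint(0, 2, size=(4, 1)).astype("float32")
history = model.fit(images, labels, epochs=1, verbose=0)

# Rank check on concrete tensors.
demo = tf.math.squared_difference(tf.ones((2, H, W, 1)), tf.zeros((2, H, W, 1)))
print(demo.shape.rank, tf.reduce_mean(demo).shape.rank)  # 4 0
```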

You can pass a loss to compile() only if it has the standard loss function signature (y_true, y_pred). You cannot use it here, because your signature is effectively (y_true, (y_pred1, y_pred2)). Use the add_loss() API instead. See here: https://keras.io/api/losses/
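A minimal sketch of that add_loss approach, assuming the extra term is the mean squared gap between the two predictions (the post does not say exactly what custom_mse computes). PredictionGapLoss is a hypothetical helper layer. It calls self.add_loss inside a layer's call(), which works for functional models in both Keras 2 and Keras 3, whereas symbolic model.add_loss on functional-model tensors was removed in Keras 3.

```python
import numpy as np
import tensorflow as tf


class PredictionGapLoss(tf.keras.layers.Layer):
    """Hypothetical layer: adds mean((pred_edited - pred_original)^2) as a
    model loss through add_loss and passes both predictions on unchanged."""

    def call(self, inputs):
        pred_edited, pred_original = inputs
        self.add_loss(tf.reduce_mean(tf.square(pred_edited - pred_original)))
        return [pred_edited, pred_original]


# Toy two-output network (assumption), standing in for complete_model.
inp = tf.keras.Input(shape=(4,))
pred_edited = tf.keras.layers.Dense(1, activation="sigmoid")(inp)
pred_original = tf.keras.layers.Dense(1, activation="sigmoid")(inp)
out_edited, out_original = PredictionGapLoss()([pred_edited, pred_original])
model = tf.keras.Model(inp, [out_edited, out_original])

# No loss in compile(): the add_loss term is the whole objective here.
model.compile(optimizer="adam")
x = np.random.rand(8, 4).astype("float32")
history = model.fit(x, epochs=1, verbose=0)
```

Applied to the question's model, the layer would take [out_model, out_model2] as its input, and complete_model would be built from the layer's outputs.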

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address. Any question please contact: yoyou2525@163.com.

 