I'm having trouble understanding and implementing a custom loss function in my model.
I have a Keras model composed of 3 sub-models (model, model_1 and model_2), as shown in the model architecture.
Now, I'd like to use the outputs of model and model_2 in my custom loss function. I understand that in the loss function definition I can write:
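import tensorflow as tf

def custom_mse(y_true, y_pred):
    # *calculate stuff*, e.g. a plain mean squared error
    loss = tf.reduce_mean(tf.square(y_true - y_pred))
    return loss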
But how do I tell the model to pass its 2 outputs to the loss function?
Maybe (and I hope so) it's trivial, but I couldn't find anything online. If you could help me, it'd be fantastic.
Thanks in advance
Context: model and model_2 are the same pretrained model, a binary classifier that predicts the interaction between 2 image-like inputs. model_1 is a generative model that edits one of the inputs.
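For reference, here is a minimal sketch of how such a complete_model could be wired. Everything in it is an assumption for illustration: toy 8x8x1 inputs, tiny stand-ins for the pretrained classifier and for model_1, and model_1 editing input_2. Calling one frozen classifier instance twice is a design choice that keeps model and model_2 identical by construction.

```python
import numpy as np
from tensorflow.keras import Model, layers

H, W = 8, 8  # hypothetical toy image size

def make_classifier():
    # Toy stand-in for the pretrained binary classifier (model / model_2)
    a, b = layers.Input((H, W, 1)), layers.Input((H, W, 1))
    x = layers.Flatten()(layers.Concatenate()([a, b]))
    return Model([a, b], layers.Dense(1, activation="sigmoid")(x), name="classifier")

def make_generator():
    # Toy stand-in for model_1, which produces an edited copy of one input
    a = layers.Input((H, W, 1))
    return Model(a, layers.Conv2D(1, 3, padding="same")(a), name="generator")

classifier = make_classifier()
classifier.trainable = False  # assumption: the pretrained classifier stays frozen
generator = make_generator()

input_1 = layers.Input((H, W, 1))
input_2 = layers.Input((H, W, 1))
edited_2 = generator(input_2)                 # assumption: model_1 edits input_2
out_model = classifier([input_1, input_2])    # interaction with the unedited input
out_model2 = classifier([input_1, edited_2])  # interaction with the edited input
complete_model = Model(inputs=[input_1, input_2], outputs=[out_model, out_model2])

# Quick shape check on a dummy batch of 2 samples
batch = [np.zeros((2, H, W, 1), "float32")] * 2
p1, p2 = complete_model.predict(batch, verbose=0)
```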
Therefore:
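from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam

complete_model = Model(inputs=[input_1, input_2], outputs=[out_model, out_model2])
opt = Adam()  # *an optimizer*
complete_model.compile(loss=custom_mse,
                       # ?????? <- how do I pass both outputs to custom_mse?
                       optimizer=opt,
                       metrics=['accuracy'])  # 'whatever'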
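As written, compile() applies a single loss to each output separately (at least in tf.keras 2.x), so custom_mse never sees both predictions at once. One common workaround, different from the add_loss() approach suggested in the answer, is to concatenate both predictions into one output and slice them apart inside a standard (y_true, y_pred) loss. A sketch under assumptions: the toy model, the hypothetical paired_loss, and its formula (squared difference between the two predictions, ignoring y_true) are placeholders.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, layers

def paired_loss(y_true, y_pred):
    # Hypothetical loss: y_pred has shape (batch, 2),
    # column 0 = out_model, column 1 = out_model2
    p_orig = y_pred[:, 0:1]
    p_edit = y_pred[:, 1:2]
    # Per-sample squared difference; y_true is unused in this placeholder
    return tf.reduce_mean(tf.square(p_orig - p_edit), axis=-1)

inp = layers.Input((4,))
out_model = layers.Dense(1, activation="sigmoid", name="head_a")(inp)
out_model2 = layers.Dense(1, activation="sigmoid", name="head_b")(inp)
# Merge both predictions so compile() sees a single output
merged = layers.Concatenate(name="both_preds")([out_model, out_model2])
model = Model(inp, merged)
model.compile(loss=paired_loss, optimizer="adam")

x = np.random.rand(16, 4).astype("float32")
dummy_y = np.zeros((16, 1), dtype="float32")  # fit() still needs targets
history = model.fit(x, dummy_y, epochs=1, verbose=0)
```

The catch is that fit() still expects targets, hence the dummy array; if the loss should also use real labels, they can be packed into y_true and sliced the same way.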
The main goal is to compare the prediction made with the edited input against the one made with the unedited input, so the model outputs the 2 interactions, which I need to use in the loss function.
EDIT: Thank you, Andrey, for the solution.
However, I now can't manage to use the 2 losses together, namely the one added with add_loss(func) and a classic binary_crossentropy set in model.compile(loss='binary_crossentropy', ...).
Could I instead add a second add_loss() term using model_2.output and the label? If so, do you know how?
Each works on its own but not together: when I try to run the code, I get:
ValueError: Shapes must be equal rank, but are 0 and 4 From merging shape 0 with other shapes. for '{{node AddN}} = AddN[N=2, T=DT_FLOAT](binary_crossentropy/weighted_loss/value, complete_model/generator/tf_op_layer_SquaredDifference_3/SquaredDifference_3)' with input shapes: [], [?,500,400,1].
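Judging by the shapes in the message, the compiled binary_crossentropy is a scalar ([]) while the add_loss term is still a per-pixel tensor ([?,500,400,1]), so Keras cannot add the two with AddN. Reducing the add_loss term to a scalar should let both losses coexist. A sketch, assuming the add_loss term is a squared difference between the generator output and its input (the SquaredDifference node hints at this) and using toy 8x8x1 images; EditPenalty is a hypothetical helper, and registering the loss from inside a layer's call() follows the pattern used in the Keras guides.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, layers

class EditPenalty(layers.Layer):
    """Hypothetical pass-through layer that registers a penalty via add_loss."""

    def call(self, inputs):
        original, edited = inputs
        sq = tf.square(edited - original)  # shape (batch, H, W, 1), like the error
        # Reduce to a scalar before add_loss; adding `sq` unreduced is what
        # triggers "Shapes must be equal rank, but are 0 and 4".
        self.add_loss(tf.reduce_mean(sq))
        return tf.identity(edited)

img = layers.Input((8, 8, 1))
edited = layers.Conv2D(1, 3, padding="same", name="generator")(img)
edited = EditPenalty()([img, edited])
prob = layers.Dense(1, activation="sigmoid")(layers.Flatten()(edited))
model = Model(img, prob)
# The standard (y_true, y_pred) loss still goes through compile()
model.compile(optimizer="adam", loss="binary_crossentropy")

x = np.random.rand(16, 8, 8, 1).astype("float32")
y = np.random.randint(0, 2, size=(16, 1)).astype("float32")
history = model.fit(x, y, epochs=1, verbose=0)
```

The reported training loss is then the binary_crossentropy plus the scalar penalty.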
You can pass a loss to compile() only if it has the standard loss function signature (y_true, y_pred). You cannot use it here, because your signature is something like (y_true, (y_pred1, y_pred2)). Use the add_loss() API instead. See here: https://keras.io/api/losses/
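A minimal sketch of the add_loss() route for the original two-output question, assuming (hypothetically) that the loss is an MSE between the two interaction predictions. The PairedMSE layer and the toy Dense stand-ins are illustrative only; since the whole loss comes from add_loss(), compile() gets no loss= and fit() gets no targets.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, layers

class PairedMSE(layers.Layer):
    """Hypothetical pass-through layer: MSE between two predictions via add_loss."""

    def call(self, inputs):
        out_model, out_model2 = inputs
        # Scalar loss built from BOTH outputs, which compile(loss=...) cannot express
        self.add_loss(tf.reduce_mean(tf.square(out_model - out_model2)))
        return [tf.identity(out_model), tf.identity(out_model2)]

x_in = layers.Input((4,))
shared = layers.Dense(1, activation="sigmoid")  # same "classifier" applied twice
out_model = shared(x_in)
out_model2 = shared(layers.Dense(4)(x_in))      # toy stand-in for the edited input
out_model, out_model2 = PairedMSE()([out_model, out_model2])
complete_model = Model(x_in, [out_model, out_model2])
complete_model.compile(optimizer="adam")  # no loss=: the add_loss term is the loss

x = np.random.rand(16, 4).astype("float32")
history = complete_model.fit(x, epochs=1, verbose=0)  # no targets needed
```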