
Mixed loss function for multiple output keras model

I have a Keras model with a single input and two outputs. The outputs are separate because one has a linear activation (to estimate a linear regression value) and the other has a softmax activation (an experiment in learning a confidence value, because the input data is noisy).

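from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

in_layer = Input((1,))
Hlayer1 = Dense(4, activation='linear')(in_layer)
Hlayer2 = Dense(4, activation='relu')(Hlayer1)
out_1 = Dense(1, activation='linear')(Hlayer2)   # regression estimate
out_2 = Dense(1, activation='softmax')(Hlayer2)  # intended confidence value
model = Model(inputs=[in_layer], outputs=[out_1, out_2])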
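One thing worth checking in the model above: softmax normalises across the units of a layer, so `Dense(1, activation='softmax')` outputs exactly 1.0 for any input, and out_2 can never vary as a confidence. The plain-Python sketch below (independent of Keras) illustrates this. `sigmoid` is shown only as one commonly used alternative that maps a single unit to a value in (0, 1); whether it suits this use is an assumption.

```python
import math

def softmax(logits):
    # Standard softmax: exponentiate, then normalise so the values sum to 1.
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(z):
    # Logistic function: maps any real z into (0, 1) on its own.
    return 1.0 / (1.0 + math.exp(-z))

# With a single unit, softmax divides one value by itself, so it is always [1.0].
for z in (-3.0, 0.0, 2.5):
    print(z, softmax([z]), round(sigmoid(z), 4))
```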

I'd like to create a mixed loss function of the form:

loss = (1 - out_2) * MSE(out_1) + out_2 * MSE(out_1)

The idea is to capture the uncertainty of the answer in the out_2 prediction and the actual answer in the out_1 prediction.
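Writing $o_1$ for out_1, $o_2$ for out_2 and $y$ for the target, and assuming MSE(out_1) means the squared error between out_1 and the target, the formula as written expands as:

```latex
\begin{aligned}
\mathcal{L} &= (1 - o_2)\,\mathrm{MSE}(o_1, y) + o_2\,\mathrm{MSE}(o_1, y) \\
            &= \bigl[(1 - o_2) + o_2\bigr]\,\mathrm{MSE}(o_1, y) \\
            &= \mathrm{MSE}(o_1, y)
\end{aligned}
```

Because both terms share the same MSE factor, $o_2$ cancels out and receives zero gradient from this loss. For out_2 to learn anything, the two terms need different factors; which ones is not specified in the question.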

I've tried writing a custom loss function and can get toy examples working on a single-output model. With a multi-output model, however, the loss function is called separately for each output, so I can't access both outputs inside the single loss function that a mixed loss needs.
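As a simplified model of how Keras combines losses for a multi-output model (not Keras's actual implementation): each output gets its own loss function, called with only that output's targets and predictions, and the total is a weighted sum controlled by `loss_weights`. The helper `total_loss` and the data below are hypothetical; the point is that the loss for output 1 never sees output 2's prediction.

```python
def mse(y_true, y_pred):
    # Mean squared error over a batch of scalar targets.
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def total_loss(loss_fns, loss_weights, targets, preds):
    # Each loss function is applied to ONE output in isolation,
    # so no single term can combine predictions from different outputs.
    return sum(w * fn(t, p)
               for fn, w, t, p in zip(loss_fns, loss_weights, targets, preds))

# Hypothetical batch of two samples for each output.
y1_true, y1_pred = [1.0, 2.0], [1.5, 2.0]
y2_true, y2_pred = [0.0, 1.0], [0.2, 0.9]
loss = total_loss([mse, mse], [1.0, 0.5],
                  [y1_true, y2_true], [y1_pred, y2_pred])
print(loss)
```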

Any advice on achieving this?

Thanks!

Here is a workaround if you would like to stay in Keras rather than dropping down to TensorFlow:
1) Concatenate the two outputs into one tensor and use it as the model output.
2) Write a custom loss function that slices that tensor back into the two outputs.

The example code could look like this:

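from tensorflow.keras.layers import Input, Dense, Concatenate
from tensorflow.keras.models import Model

in_layer = Input((1,))
Hlayer1 = Dense(4, activation='linear')(in_layer)
Hlayer2 = Dense(4, activation='relu')(Hlayer1)
out_1 = Dense(1, activation='linear')(Hlayer2)
# sigmoid instead of softmax: softmax over a single unit always outputs 1.0
out_2 = Dense(1, activation='sigmoid')(Hlayer2)
# stack both outputs into one (batch, 2) tensor so a single loss sees both
out_concat = Concatenate(axis=-1)([out_1, out_2])
model = Model(inputs=[in_layer], outputs=out_concat)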

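import tensorflow as tf

def my_loss(y_true, y_pred):
    out_1 = y_pred[:, 0:1]  # regression output
    out_2 = y_pred[:, 1:2]  # confidence output
    sq_err = tf.square(y_true[:, 0:1] - out_1)  # assumes targets of shape (batch, 1)
    # replace with your own mixed loss; this is the form from the question
    return tf.reduce_mean((1.0 - out_2) * sq_err + out_2 * sq_err, axis=-1)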
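To see what the custom loss receives, here is a NumPy-only sketch (no Keras needed) that mimics the concatenated output: the two prediction columns are stacked with `np.concatenate` along the last axis, as `Concatenate(axis=-1)` does, and the loss slices them back apart. `my_loss_np` is a hypothetical stand-in that uses the mixed form from the question and assumes targets of shape (batch, 1). With that particular formula the per-sample values equal the plain squared error, because the out_2 weights sum to 1.

```python
import numpy as np

def my_loss_np(y_true, y_pred):
    # y_pred has shape (batch, 2): column 0 is out_1, column 1 is out_2.
    out_1 = y_pred[:, 0:1]
    out_2 = y_pred[:, 1:2]
    sq_err = np.square(y_true[:, 0:1] - out_1)
    # Per-sample mixed loss that uses both outputs at once.
    return np.mean((1.0 - out_2) * sq_err + out_2 * sq_err, axis=-1)

# Hypothetical predictions for a batch of three samples.
out_1 = np.array([[1.5], [2.0], [0.0]])
out_2 = np.array([[0.9], [0.1], [0.5]])
y_pred = np.concatenate([out_1, out_2], axis=-1)  # shape (3, 2)
y_true = np.array([[1.0], [2.0], [1.0]])
print(my_loss_np(y_true, y_pred))
```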

Cheers,

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM