Keras implementation: Custom loss function with uncertainty output

For a regression task, I'd like to customize the loss function so that the model additionally outputs a certainty measure.

The initial, plain network would be:

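from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# Baseline regressor: 7 input features, one linear output trained with MSE
model = Sequential()
model.add(Dense(15, input_dim=7, kernel_initializer='normal'))
model.add(Dense(1, kernel_initializer='normal'))
model.compile(loss='mean_squared_error', optimizer='adam')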

I'd like to add a certainty indicator, sigma, to the loss function, so that depending on how accurate the predictions are, different sigma values lead to the minimal loss:

loss = (y_pred-y_true)^2/(2*sigma^2) + log(sigma)
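Setting the derivative of this loss with respect to sigma to zero shows why sigma ends up tracking the prediction error (log taken as the natural logarithm). For a fixed error the loss is minimised at sigma = |y_pred - y_true|, so poorly predicted points push sigma up and well-predicted points push it down. Up to the constant ½ log(2π), this is the negative log-likelihood of a Gaussian with mean y_pred and standard deviation sigma.

```latex
\frac{\partial}{\partial \sigma}\left[\frac{(y_{pred}-y_{true})^2}{2\sigma^2} + \log\sigma\right]
= -\frac{(y_{pred}-y_{true})^2}{\sigma^3} + \frac{1}{\sigma} = 0
\quad\Longrightarrow\quad
\sigma = \lvert y_{pred}-y_{true} \rvert
```

The second derivative at that point is 2/σ² > 0, so it is indeed a minimum.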

The final outputs of the NN would then be y_pred and sigma.

I'm a bit lost on the implementation (new to Keras):

  1. Where would we initialize/store sigma so that it gets updated for recurring, similar data points during training?
  2. How do we connect the variable sigma in the loss function to the second NN output?

My current base structure, where I'm obviously missing some pieces:

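import math
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# Attempt: Keras only ever calls a loss as loss(y_true, y_pred), so nothing supplies sigma here
def custom_loss(y_true, y_pred, sigma):
    loss = pow((y_pred - y_true), 2)/(2 * pow(sigma, 2))+math.log(sigma)
    return loss, sigma

model = Sequential()
model.add(Dense(15, input_dim=7, kernel_initializer='normal'))
model.add(Dense(2, kernel_initializer='normal'))  # intended as [prediction, sigma]
model.compile(loss=custom_loss, optimizer='adam')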

Any tips or guidance are highly appreciated. Thanks!

The key is to extend y_pred from a scalar to a vector, so that the network's second output serves as sigma:

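import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Lambda

def custom_loss(y_true, y_pred):
    # y_pred has shape (batch, 2): column 0 is the prediction, column 1 is sigma
    mu, sigma = y_pred[:, 0:1], y_pred[:, 1:2]
    y_true = tf.reshape(tf.cast(y_true, mu.dtype), (-1, 1))  # align to (batch, 1)
    return tf.square(mu - y_true) / (2 * tf.square(sigma)) + tf.math.log(sigma)

def positive_sigma(x):
    # Leave the prediction linear; softplus keeps sigma > 0 so log(sigma) is defined
    return tf.concat([x[:, 0:1], tf.nn.softplus(x[:, 1:2]) + 1e-6], axis=-1)

model = Sequential()
model.add(Dense(15, input_dim=7, kernel_initializer='normal'))
model.add(Dense(2, kernel_initializer='normal'))
model.add(Lambda(positive_sigma))
model.compile(loss=custom_loss, optimizer='adam')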
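Inside a Keras loss, y_pred arrives as a whole batch of shape (batch_size, 2), so the prediction and sigma live in columns and have to be sliced as y_pred[:, 0] and y_pred[:, 1]; y_pred[0] would instead pick the first sample's row. The small NumPy sketch below illustrates the difference on a made-up batch of three samples (the values are hypothetical, chosen only for illustration) and computes the per-sample loss from the question.

```python
import numpy as np

# Hypothetical batch of 3 samples as a Dense(2) output layer would emit it:
# column 0 = prediction, column 1 = sigma
y_pred = np.array([[1.0, 0.5],
                   [2.0, 0.1],
                   [3.0, 2.0]])
y_true = np.array([[1.2], [2.0], [0.0]])

first_row = y_pred[0]    # [prediction, sigma] of sample 0 only, shape (2,)
mu = y_pred[:, 0:1]      # predictions of every sample, shape (3, 1)
sigma = y_pred[:, 1:2]   # sigma of every sample, shape (3, 1)

# Per-sample loss: (y_pred - y_true)^2 / (2 sigma^2) + log(sigma)
per_sample_loss = (mu - y_true) ** 2 / (2 * sigma ** 2) + np.log(sigma)
print(first_row.shape, mu.shape, per_sample_loss.shape)
```

Keras then averages the per-sample values over the batch to get the scalar that the optimizer minimises.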

The model then returns sigma alongside each prediction:

Y = model.predict(X)  # shape (n_samples, 2): Y[:, 0] = prediction, Y[:, 1] = sigma
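A plain NumPy version of the same per-sample loss can be handy for sanity-checking the returned columns offline. The sketch below uses a hypothetical helper, gaussian_nll (not part of Keras), and scans candidate sigmas for a fixed error of 0.5 to confirm that the loss bottoms out when sigma equals the absolute error.

```python
import numpy as np

def gaussian_nll(y_true, pred, sigma):
    """Hypothetical helper: per-sample loss (pred - y)^2 / (2 sigma^2) + log(sigma)."""
    return (pred - y_true) ** 2 / (2 * sigma ** 2) + np.log(sigma)

# Fix the error at 0.5 and scan candidate sigma values on a grid
error = 0.5
sigmas = np.linspace(0.05, 2.0, 400)
losses = gaussian_nll(0.0, error, sigmas)
best_sigma = sigmas[np.argmin(losses)]
print(best_sigma)
```

The grid minimiser lands at roughly 0.5, which is why a trained network tends to output small sigma where its predictions are accurate and large sigma where they are not.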
