[英]Keras implementation: Custom loss function with uncertainty output
For a regression task I'd like to customize the loss function to output a certainty measure additionally. 对于回归任务,我想自定义损失函数以另外输出确定性度量。
The initial normal network would be: 最初的正常网络为:
model = Sequential()
model.add(Dense(15, input_dim=7, kernel_initializer='normal'))
model.add(Dense(1, kernel_initializer='normal'))
model.compile(loss='mean_squared_error', optimizer='adam')
I'd like to add a certainty indicator sigma
to the loss function. 我想向损失函数添加确定性指标
sigma
。 Eg depending on how accurate the predictions are different sigma
sizes lead to minimal loss. 例如,取决于预测的准确性如何,不同的
sigma
大小会导致最小的损失。
loss = (y_pred-y_true)^2/(2*sigma^2) + log(sigma)
The final outputs of the NN would then be y_pred
and sigma
. NN的最终输出将是
y_pred
和sigma
。
I'm a bit lost in the implementation (new to keras): 我在实现中有点迷路(对keras来说是新的):
sigma
for it to be updated around recurring, similar datapoints during training. sigma
,以便在训练过程中围绕类似的重复数据点进行更新。 sigma
from the loss function to the second NN output. sigma
从损失函数连接到第二个NN输出。 My current base stucture, where I'm obviously lacking the pieces 我目前的基础课程,显然我缺少这些基础知识
def custom_loss(y_true, y_pred, sigma):
loss = pow((y_pred - y_true), 2)/(2 * pow(sigma, 2))+math.log(sigma)
return loss, sigma
model = Sequential()
model.add(Dense(15, input_dim=7, kernel_initializer='normal'))
model.add(Dense(2, kernel_initializer='normal'))
model.compile(loss=custom_loss, optimizer='adam')
Any tips/guidances are highly appreciated. 任何提示/指导都受到高度赞赏。 Thanks!
谢谢!
The key is to extend y_pred from a scalar to a vector 关键是将y_pred从标量扩展为向量
def custom_loss(y_true, y_pred):
loss = pow((y_pred[0] - y_true), 2) / (2 * pow(y_pred[1], 2)) + \
tf.math.log(y_pred[1])
return loss
model = Sequential()
model.add(Dense(15, input_dim=7, kernel_initializer='normal'))
model.add(Dense(2, kernel_initializer='normal'))
model.compile(loss=custom_loss, optimizer='adam')
The model then returns the sigma to the prediction. 然后,模型将sigma返回到预测。
Y = model.predict(X) # Y = [prediction, sigma]
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.