I want to calculate a weighted mean squared error, where the weights are a vector in my data. I wrote custom code based on suggestions available on Stack Overflow.
The function is provided below:
```r
weighted_mse <- function(y_true, y_pred, weights) {
  # convert tensors to R objects
  K <- backend()
  y_true <- K$eval(y_true)
  y_pred <- K$eval(y_pred)
  weights <- K$eval(weights)
  # calculate the metric
  loss <- sum(weights * ((y_true - y_pred)^2))
  # convert to tensor
  return(K$constant(loss))
}
```
However, I am not sure how to pass the custom function to `compile()`; my current call is below. It would be great if someone could help. Thank you.
```r
model <- model %>% compile(
  loss = 'mse',
  optimizer = 'rmsprop',
  metrics = 'mse')
```
Regards
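The question's `weighted_mse` computes a weighted *sum* of squared errors, not a mean. For reference, here is a plain-R sketch (no Keras, no tensors) of a normalized weighted MSE, assuming the weights should be normalized by their sum. `weighted_mse_ref` is a hypothetical helper name used only for illustration.

```r
# Plain-R reference for a *normalized* weighted MSE (no Keras involved).
# weighted_mse_ref is a hypothetical helper name, used only for illustration.
weighted_mse_ref <- function(y_true, y_pred, w) {
  stopifnot(length(y_true) == length(y_pred), length(w) == length(y_true))
  sq_err <- (y_true - y_pred)^2
  # Dividing by sum(w) turns the weighted sum into a weighted mean
  sum(w * sq_err) / sum(w)
}

y_true <- c(1, 2, 3, 4)
y_pred <- c(1.5, 2, 2, 4)
w      <- c(1, 1, 2, 0)

# What the question's function computes: a weighted sum
weighted_sum <- sum(w * (y_true - y_pred)^2)     # 2.25
weighted_mse_ref(y_true, y_pred, w)              # 0.5625
# Base R's weighted.mean() gives the same normalized value
weighted.mean((y_true - y_pred)^2, w)            # 0.5625
```

Which normalization is "right" depends on what you want the number to mean; the sum grows with the number of samples, while the normalized version stays on the scale of an ordinary MSE.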
I haven't used Keras with R, but following the example from the documentation, something like this should probably work:
```r
weighted_mse <- function(y_true, y_pred, weights) {
  K <- backend()
  weights <- K$variable(weights)
  # calculate the metric
  loss <- K$sum(weights * (K$pow(y_true - y_pred, 2)))
  loss
}

metric_weighted_mse <- custom_metric("weighted_mse", function(y_true, y_pred) {
  weighted_mse(y_true, y_pred, weights)
})

model <- model %>% compile(
  loss = 'mse',
  optimizer = 'rmsprop',
  metrics = metric_weighted_mse)
```
Note that I'm using a wrapper around the function because it takes an extra parameter. Also, the function processes its inputs as tensors, which is why you should convert the weights with `K$variable(weights)`.
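The wrapper works because Keras calls a metric as `f(y_true, y_pred)`, so any extra argument has to be captured in a closure. Below is a plain-R sketch of that closure pattern; it uses ordinary R vectors instead of backend tensor ops so it runs without Keras, and `make_weighted_metric` is a hypothetical name. In a real metric the body would use backend operations as in the answer above.

```r
# A metric passed to compile() is called with two arguments, (y_true, y_pred).
# A closure "bakes in" any extra parameter, here the weights.
# make_weighted_metric is a hypothetical name used only for illustration.
make_weighted_metric <- function(weights) {
  force(weights)  # capture the value now, not lazily at the first call
  function(y_true, y_pred) {
    sum(weights * (y_true - y_pred)^2)
  }
}

metric_fn <- make_weighted_metric(c(1, 2))
names(formals(metric_fn))     # "y_true" "y_pred"
metric_fn(c(0, 0), c(1, 1))   # 1 * 1 + 2 * 1 = 3
```

Using a factory function instead of a global `weights` variable makes it explicit which weights the metric uses, though it does not by itself solve the batch-alignment problem raised in the other answer.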
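You can't `eval` in loss functions. This will break the graph: the result is no longer a symbolic tensor, so gradients can't flow back through it to the model's weights.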
You should just use the `sample_weight` parameter of the `fit` method: https://keras.rstudio.com/reference/fit.html
```r
# At some point you will call `fit` for training with `X_train` and `Y_train`,
# so just add the weights there.
history <- model %>% fit(X_train, Y_train, ..., sample_weight = weights)
```
That's all (don't use a custom loss).
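To make concrete what `sample_weight` does to the loss, here is a plain-R sketch of a per-batch weighted MSE. It assumes the TensorFlow 2.x default reduction (`SUM_OVER_BATCH_SIZE`), which divides the weighted sum by the number of samples in the batch rather than by the sum of the weights; other Keras versions may normalize differently. `keras_style_weighted_mse` is a hypothetical name, not a Keras function.

```r
# Sketch of how a weighted MSE is reduced over one batch, ASSUMING the
# TF 2.x default reduction: sum(w * per-sample loss) / number of samples.
# keras_style_weighted_mse is a hypothetical name, not part of Keras.
keras_style_weighted_mse <- function(y_true, y_pred, sample_weight) {
  per_sample <- (y_true - y_pred)^2   # single-output model: one loss per row
  sum(sample_weight * per_sample) / length(per_sample)
}

y_true <- c(1, 2, 3)
y_pred <- c(1, 1, 1)
keras_style_weighted_mse(y_true, y_pred, c(1, 1, 1))  # plain MSE: 5/3
keras_style_weighted_mse(y_true, y_pred, c(1, 2, 3))  # (0 + 2 + 12) / 3 = 14/3
```

With all weights equal to 1 this reduces to the ordinary MSE, which is a quick sanity check when choosing weights.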
Just for knowledge: passing a loss function to `compile` only works for functions taking `y_true` and `y_pred` (not necessary if you're using `sample_weight`):
```r
model <- model %>% compile(
  loss = weighted_mse,
  optimizer = 'rmsprop',
  metrics = 'mse')
```
But this won't work as-is, because `weighted_mse` takes a third argument (`weights`); you need something similar to the wrapper created by @spadarian.
Also, it will be very hard to keep your data and the weights aligned, both because Keras splits your data into batches and because the data is shuffled.
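The alignment problem goes away when the weights are sliced with the same row indices as `x` and `y`, which is, roughly, what `fit()` does with `sample_weight`. The plain-R sketch below (no Keras) shuffles once and slices all three arrays per batch; `make_batches` is a hypothetical helper, and the batch size of 4 is an arbitrary choice for the example.

```r
# Plain-R sketch (no Keras): shuffle once, then slice x, y AND the weights
# with the same row indices, so each batch carries its own weights.
# make_batches is a hypothetical helper name.
make_batches <- function(n, batch_size, shuffle = TRUE) {
  idx <- if (shuffle) sample.int(n) else seq_len(n)
  split(idx, ceiling(seq_along(idx) / batch_size))
}

set.seed(1)
n <- 10
x <- matrix(rnorm(n * 2), ncol = 2)
y <- rnorm(n)
w <- seq_len(n) / n            # one weight per row

batches <- make_batches(n, batch_size = 4)
for (b in batches) {
  xb <- x[b, , drop = FALSE]
  yb <- y[b]
  wb <- w[b]                   # sliced exactly like yb, so rows stay aligned
  stopifnot(length(wb) == length(yb), nrow(xb) == length(yb))
}
# A full-length `w` captured in a closure would instead always have n entries,
# with no link to which rows ended up in a given batch.
```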