I wanted to evaluate the performance of several regression models and used the yardstick
package to calculate the RMSE. Here is some example data:
  model obs pred
1     A   1    1
2     B   1    2
3     C   1    3
When I run the following code
library(yardstick)
library(dplyr)
dat %>%
  group_by(model) %>%
  summarise(RMSE = yardstick::rmse(truth = obs, estimate = pred))
I get the following error
Error in summarise_impl(.data, dots) : no applicable method for 'rmse' applied to an object of class "c('double', 'numeric')".
However, when I explicitly supply the dot (.) as the first argument (which I thought should not be necessary), I get no error, but the results are incorrect.
dat %>%
  group_by(model) %>%
  summarise(RMSE = yardstick::rmse(., truth = obs, estimate = pred))
# A tibble: 3 x 2
  model   RMSE
  <fctr> <dbl>
1 A       1.29
2 B       1.29
3 C       1.29
I was expecting the following
# A tibble: 3 x 2
  model   RMSE
  <fctr> <dbl>
1 A       0
2 B       1.00
3 C       2.00
I know that there are alternatives to this function but still I don't understand this behavior.
data
dat <- structure(list(model = structure(1:3, .Label = c("A", "B", "C"), class = "factor"), obs = c(1, 1, 1), pred = 1:3), .Names = c("model", "obs", "pred"), row.names = c(NA, -3L), class = "data.frame")
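Based on the help page ?yardstick::rmse, it looks like the function expects a data frame as its first argument, which explains the error you're getting: inside summarise(), the first thing it receives is a numeric vector, and there is no rmse method for that class.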
I'm not quite up to speed on that new package, but it seems the function is meant to calculate a summary statistic across a whole data frame rather than to be called inside summarise(). To get one value per model, you need to turn each group into its own data frame (here every group happens to be a single row) and apply the function to each of those data frames:
library(tidyverse)
library(yardstick)  # provides rmse(); not attached by library(tidyverse)
dat %>%
  group_by(model) %>%
  nest() %>%
  mutate(rmse_res = map(data, rmse, truth = obs, estimate = pred)) %>%
  unnest(rmse_res)
# A tibble: 3 x 3
model data rmse
<fctr> <list> <dbl>
1 A <tibble [1 x 2]> 0
2 B <tibble [1 x 2]> 1.00
3 C <tibble [1 x 2]> 2.00
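The 1.29 from the question is consistent with the dot referring to the whole data frame rather than to the current group. A minimal base R sketch of that arithmetic follows; rmse_manual is a hypothetical helper written only for this illustration and is not part of yardstick.

```r
# Example data from the question
dat <- data.frame(
  model = factor(c("A", "B", "C")),
  obs   = c(1, 1, 1),
  pred  = 1:3
)

# Hypothetical helper (not a yardstick function): plain RMSE of two numeric vectors
rmse_manual <- function(truth, estimate) {
  sqrt(mean((truth - estimate)^2))
}

# RMSE over all three rows at once: sqrt((0 + 1 + 4) / 3)
overall <- rmse_manual(dat$obs, dat$pred)
round(overall, 2)
# 1.29

# RMSE computed separately for each model
per_model <- sapply(split(dat, dat$model),
                    function(d) rmse_manual(d$obs, d$pred))
per_model
# A B C
# 0 1 2
```

The first value matches what every group received in the question, and the per-model values match the expected output.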
We can use the do function to apply the rmse function to every group.
dat %>%
  group_by(model) %>%
  do(data_frame(model = .$model[1], obs = .$obs[1], pred = .$pred[1],
                RMSE = yardstick::rmse(., truth = obs, estimate = pred)))
# # A tibble: 3 x 4
# # Groups: model [3]
# model obs pred RMSE
# <fctr> <dbl> <int> <dbl>
# 1 A 1.00 1 0
# 2 B 1.00 2 1.00
# 3 C 1.00 3 2.00
Or we can split the data frame and apply the rmse function.
dat %>%
  mutate(RMSE = dat %>%
           split(.$model) %>%
           sapply(yardstick::rmse, truth = obs, estimate = pred))
# model obs pred RMSE
# 1 A 1 1 0
# 2 B 1 2 1
# 3 C 1 3 2
Or we can nest the obs and pred columns into a list column and then apply the rmse function.
library(tidyr)
dat %>%
  nest(obs, pred) %>%
  mutate(RMSE = sapply(data, yardstick::rmse, truth = obs, estimate = pred)) %>%
  unnest()
# model RMSE obs pred
# 1 A 0 1 1
# 2 B 1 1 2
# 3 C 2 1 3
The outputs of these three methods differ slightly, but all contain the correct RMSE values. Here I use the microbenchmark package to compare their performance.
library(microbenchmark)
microbenchmark(
  m1 = {dat %>%
    group_by(model) %>%
    do(data_frame(model = .$model[1], obs = .$obs[1], pred = .$pred[1],
                  RMSE = yardstick::rmse(., truth = obs, estimate = pred)))},
  m2 = {dat %>%
    mutate(RMSE = dat %>%
             split(.$model) %>%
             sapply(yardstick::rmse, truth = obs, estimate = pred))},
  m3 = {dat %>%
    nest(obs, pred) %>%
    mutate(RMSE = sapply(data, yardstick::rmse, truth = obs, estimate = pred)) %>%
    unnest()})
# Unit: milliseconds
# expr min lq mean median uq max neval
# m1 43.18746 46.71055 50.23383 48.46554 51.05639 174.46371 100
# m2 14.08516 14.78093 16.14605 15.74505 16.89936 24.02136 100
# m3 28.99795 30.90407 32.71092 31.89954 33.94729 44.57953 100
The result shows that m2 is the fastest, while m1 is the slowest, at least on this three-row example. I think the implication is that the do operation is usually slower than the other methods, so if possible, we should avoid it. Although m2 is the fastest, personally I like the syntax of m3 the best. The nested data frame makes it easy to summarize information across different models or groups.
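For completeness, here is a sketch of two shorter routes that avoid do, split and nest altogether. It assumes a later yardstick release than the one used above: one that exports the vector function rmse_vec() and whose rmse() respects groups set by group_by(). Neither assumption is confirmed for the version in this thread, and the name of the result column (.estimate) may differ between versions.

```r
library(dplyr)
library(yardstick)

# Example data from the question
dat <- data.frame(
  model = factor(c("A", "B", "C")),
  obs   = c(1, 1, 1),
  pred  = 1:3
)

# Option 1: the vector interface works on plain numeric vectors,
# so it can be called directly inside summarise()
res_vec <- dat %>%
  group_by(model) %>%
  summarise(RMSE = rmse_vec(truth = obs, estimate = pred))

# Option 2: pass the grouped data frame straight to rmse();
# assumed to return one row per group in recent yardstick versions
res_df <- dat %>%
  group_by(model) %>%
  rmse(truth = obs, estimate = pred)

res_vec
res_df
```

Option 1 keeps the summarise() shape from the question, while Option 2 is closer to how the data frame interface seems intended to be used.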