
yardstick::rmse on grouped data returns error and incorrect results

I wanted to evaluate the performance of several regression models and used the yardstick package to calculate the RMSE. Here is some example data:

  model obs pred
1     A   1    1
2     B   1    2
3     C   1    3

When I run the following code

library(yardstick)
library(dplyr)
dat %>%
 group_by(model) %>%
 summarise(RMSE = yardstick::rmse(truth = obs, estimate = pred))

I get the following error

Error in summarise_impl(.data, dots) : no applicable method for 'rmse' applied to an object of class "c('double', 'numeric')".

However, when I explicitly supply . as the first argument (which I thought should not be necessary), I get no error, but the results are incorrect.

dat %>%
 group_by(model) %>%
 summarise(RMSE = yardstick::rmse(., truth = obs, estimate = pred))
# A tibble: 3 x 2
  model   RMSE
  <fctr> <dbl>
1 A       1.29
2 B       1.29
3 C       1.29

I was expecting the following

# A tibble: 3 x 2
  model   RMSE
  <fctr> <dbl>
1 A       0
2 B       1.00
3 C       2.00

I know that there are alternatives to this function, but I still don't understand this behavior.
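For example, computing the RMSE by hand inside summarise() with the plain formula sqrt(mean((obs - pred)^2)) gives what I expect, so the question is really about why yardstick::rmse behaves differently:

dat %>%
 group_by(model) %>%
 summarise(RMSE = sqrt(mean((obs - pred)^2)))
# per model: A = 0, B = 1, C = 2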

data

dat <- structure(list(model = structure(1:3, .Label = c("A", "B", "C"), class = "factor"), obs = c(1, 1, 1), pred = 1:3), .Names = c("model", "obs", "pred"), row.names = c(NA, -3L), class = "data.frame")

Based on the help page ?yardstick::rmse, it looks like it expects a data frame as its first argument, which explains the error you're getting.
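That also explains the identical 1.29 values: inside summarise(), the . pronoun refers to the whole piped data frame rather than the current group, so rmse() is evaluated over all three rows every time. A quick check (assuming the yardstick version used here, where rmse() returns a single numeric):

# RMSE over the whole data frame, ignoring the grouping
sqrt(mean((dat$obs - dat$pred)^2))
# [1] 1.290994   (= sqrt((0 + 1 + 4) / 3))

yardstick::rmse(dat, truth = obs, estimate = pred)
# ~1.29, the value repeated for every group above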

I'm not quite up to speed on that new package, but it seems that the function expects to calculate a summary statistic across a whole data frame rather than row by row. To force a row-by-row calculation, you'd need to make it think that each row is its own data frame, and apply the function within each of those data frames:

library(tidyverse)
dat %>%
  group_by(model) %>%
  nest() %>% 
  mutate(rmse_res = map(data, rmse, truth = obs, estimate = pred)) %>% 
  unnest(rmse_res)

# A tibble: 3 x 3
  model  data              rmse
  <fctr> <list>           <dbl>
1 A      <tibble [1 x 2]>  0   
2 B      <tibble [1 x 2]>  1.00
3 C      <tibble [1 x 2]>  2.00

We can use the do function to apply the rmse function to every group.

dat %>%
  group_by(model) %>%
  do(data_frame(model = .$model[1], obs = .$obs[1], pred = .$pred[1], 
     RMSE = yardstick::rmse(., truth = obs, estimate = pred)))
# # A tibble: 3 x 4
# # Groups: model [3]
# model    obs  pred  RMSE
#  <fctr> <dbl> <int> <dbl>
# 1 A       1.00     1  0   
# 2 B       1.00     2  1.00
# 3 C       1.00     3  2.00

Or we can split the data frame and apply the rmse function.

dat %>%
  mutate(RMSE = dat %>%
           split(.$model) %>%
           sapply(yardstick::rmse, truth = obs, estimate = pred))
#   model obs pred RMSE
# 1     A   1    1    0
# 2     B   1    2    1
# 3     C   1    3    2

Or we can nest the obs and pred columns into a list column and then apply the rmse function.

library(tidyr)

dat %>%
  nest(obs, pred) %>%
  mutate(RMSE = sapply(data, yardstick::rmse, truth = obs, estimate = pred)) %>%
  unnest()
#   model RMSE obs pred
# 1     A    0   1    1
# 2     B    1   1    2
# 3     C    2   1    3

The outputs of these three methods are slightly different, but all contain the correct RMSE values. Here I use the microbenchmark package to compare their performance.

library(microbenchmark)

microbenchmark(m1 = {dat %>%
    group_by(model) %>%
    do(data_frame(model = .$model[1], obs = .$obs[1], pred = .$pred[1], 
                  RMSE = yardstick::rmse(., truth = obs, estimate = pred)))},
    m2 = {dat %>%
        mutate(RMSE = dat %>%
                 split(.$model) %>%
                 sapply(yardstick::rmse, truth = obs, estimate = pred))},
    m3 = {dat %>%
        nest(obs, pred) %>%
        mutate(RMSE = sapply(data, yardstick::rmse, truth = obs, estimate = pred)) %>%
        unnest()})

# Unit: milliseconds
# expr      min       lq     mean   median       uq       max neval
#   m1 43.18746 46.71055 50.23383 48.46554 51.05639 174.46371   100
#   m2 14.08516 14.78093 16.14605 15.74505 16.89936  24.02136   100
#   m3 28.99795 30.90407 32.71092 31.89954 33.94729  44.57953   100

The result shows that m2 is the fastest, while m1 is the slowest. I think the implication is that the do operation is usually slower than the other methods, so if possible we should avoid it. Although m2 is the fastest, I personally like the syntax of m3 the best. The nested data frame also makes it easy to summarize information across different models or groups, as sketched below.
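For example, the same nested layout extends naturally to several metrics per model (a sketch, assuming yardstick::mae() also returns a single numeric in this version, like rmse() does):

library(purrr)

dat %>%
  nest(obs, pred) %>%
  mutate(RMSE = map_dbl(data, yardstick::rmse, truth = obs, estimate = pred),
         MAE  = map_dbl(data, yardstick::mae,  truth = obs, estimate = pred))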
