
Custom grouped dplyr function (sample_n)

I am trying to apply a sampling function to a data frame group by group: it should draw n rows from each group, or return all rows of the group if the group has fewer than n rows.

Using dplyr, I first tried:

library(dplyr)
mtcars %>% group_by(cyl) %>% sample_n(2)

This works when n is no larger than any group size, but it raises an error instead of returning the full group when n exceeds a group's size (note that one of the cyl groups has only 7 cars):

mtcars %>% group_by(cyl) %>% sample_n(8)
Error: `size` must be less or equal than 7 (size of data), 
set `replace` = TRUE to use sampling with replacement

I tried to solve this by writing a wrapper around sample_n, like so:

sample_n_or_all <- function(tbl, n) {
  if (nrow(tbl) < n) return(tbl)
  sample_n(tbl, n)
}

but using my custom function ( mtcars %>% group_by(cyl) %>% sample_n_or_all(8) ) generates the same error.

Any suggestions on how to adapt my function so that it is applied to each group separately? Or is there another solution to the problem?

We could check the number of rows in each group with n() and pass the smaller of that and the requested size to sample_n:

library(dplyr)
n <- 8

temp <- mtcars %>% group_by(cyl) %>% sample_n(if(n() < n) n() else n) 
temp

#    mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
#   <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
# 1  21.4     4 121     109  4.11  2.78  18.6     1     1     4     2
# 2  27.3     4  79      66  4.08  1.94  18.9     1     1     4     1
# 3  24.4     4 147.     62  3.69  3.19  20       1     0     4     2
# 4  22.8     4 108      93  3.85  2.32  18.6     1     1     4     1
# 5  26       4 120.     91  4.43  2.14  16.7     0     1     5     2
# 6  33.9     4  71.1    65  4.22  1.84  19.9     1     1     4     1
# 7  30.4     4  75.7    52  4.93  1.62  18.5     1     1     4     2
# 8  30.4     4  95.1   113  3.77  1.51  16.9     1     1     5     2
# 9  21       6 160     110  3.9   2.62  16.5     0     1     4     4
#10  17.8     6 168.    123  3.92  3.44  18.9     1     0     4     4
# … with 13 more rows

We can then check the number of rows sampled from each group and compare it with the original group sizes:

table(temp$cyl)

#4 6 8 
#8 7 8 

table(mtcars$cyl)

# 4  6  8 
#11  7 14 
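
For completeness, here is a minimal sketch of how the asker's own helper could be made to run once per group. The likely reason the original attempt fails is that nrow() on a grouped data frame returns the total row count (32 for mtcars), so the guard never triggers and sample_n is still called on the 7-row group. The sketch assumes dplyr 0.8.1 or later, where group_modify() is available; it is one possible approach, not necessarily what the original poster ended up using.

```r
library(dplyr)

# The asker's helper: return the whole table when it has fewer than n rows,
# otherwise sample n rows without replacement.
sample_n_or_all <- function(tbl, n) {
  if (nrow(tbl) < n) return(tbl)
  sample_n(tbl, n)
}

# group_modify() calls the function once per group; .x holds only that
# group's rows as an ungrouped tibble, so nrow() reports the group size.
res <- mtcars %>%
  group_by(cyl) %>%
  group_modify(~ sample_n_or_all(.x, 8))

# Rows kept per cyl group
table(res$cyl)
```

The grouping column is re-attached by group_modify(), so res still contains cyl and remains grouped by it.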

We can also avoid the explicit logical condition by using pmin:

library(dplyr)
n <- 8  # requested sample size per group
tmp <- mtcars %>%
         group_by(cyl) %>%
         sample_n(pmin(n(), n))  # n() is the group size, n is the value above
tmp
# A tibble: 23 x 11
# Groups:   cyl [3]
#     mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
#   <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
# 1  33.9     4  71.1    65  4.22  1.84  19.9     1     1     4     1
# 2  27.3     4  79      66  4.08  1.94  18.9     1     1     4     1
# 3  21.4     4 121     109  4.11  2.78  18.6     1     1     4     2
# 4  30.4     4  75.7    52  4.93  1.62  18.5     1     1     4     2
# 5  21.5     4 120.     97  3.7   2.46  20.0     1     0     3     1
# 6  32.4     4  78.7    66  4.08  2.2   19.5     1     1     4     1
# 7  30.4     4  95.1   113  3.77  1.51  16.9     1     1     5     2
# 8  26       4 120.     91  4.43  2.14  16.7     0     1     5     2
# 9  17.8     6 168.    123  3.92  3.44  18.9     1     0     4     4
#10  21       6 160     110  3.9   2.62  16.5     0     1     4     4
# … with 13 more rows

Checking the group sizes:

table(tmp$cyl)
# 4 6 8 
# 8 7 8 
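
As a further alternative, assuming dplyr 1.0.0 or later is installed: sample_n() is superseded there by slice_sample(), which silently truncates n to the group size when a group has fewer rows. A short sketch under that assumption:

```r
library(dplyr)

# slice_sample() caps n at each group's size, so no manual check is needed.
res2 <- mtcars %>%
  group_by(cyl) %>%
  slice_sample(n = 8)

# Rows kept per cyl group
table(res2$cyl)
```

This keeps the pipeline to a single verb, at the cost of requiring a newer dplyr than the answers above.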
