
subset in data.table when groups are empty

For these data

library(data.table)
set.seed(42)
dat <- data.table(id=1:12, group=rep(1:3, each=4), x=rnorm(12))

> dat
    id group           x
 1:  1     1  1.37095845
 2:  2     1 -0.56469817
 3:  3     1  0.36312841
 4:  4     1  0.63286260
 5:  5     2  0.40426832
 6:  6     2 -0.10612452
 7:  7     2  1.51152200
 8:  8     2 -0.09465904
 9:  9     3  2.01842371
10: 10     3 -0.06271410
11: 11     3  1.30486965
12: 12     3  2.28664539

My goal is to get, from each group, the first id for which x is larger than some threshold, say x>1.5.

> dat[x>1.5, .SD[1], by=group]
   group id        x
1:     2  7 1.511522
2:     3  9 2.018424

is indeed correct, but I am unhappy about the fact that it silently yields no result for group 1. Instead, for each group in which no row fulfills the condition, I would like it to yield the last id of that group. I see that I could achieve this in two steps

> tmp <- dat[x>1.5, .SD[1], by=group]
> rbind(tmp,dat[!group%in%tmp$group,.SD[.N], by=group])
   group id         x
1:     2  7 1.5115220
2:     3  9 2.0184237
3:     1  4 0.6328626

but I am sure I am not making full use of the data.table capabilities here, which must permit a more elegant solution.

Using data.table , we can check for a condition and subset rows by group.

library(data.table)
dat[dat[, if(any(x>1.5)) .I[which.max(x > 1.5)] else .I[.N], by=group]$V1]

#   id group         x
#1:  4     1 0.6328626
#2:  7     2 1.5115220
#3:  9     3 2.0184237
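
To make the intermediate step explicit: the inner grouped call returns one row number per group (the first row with x > 1.5, or the last row of the group when none qualifies), and the outer dat[...] then selects those rows. A quick check, using a temporary variable idx just for illustration:

idx <- dat[, if(any(x>1.5)) .I[which.max(x > 1.5)] else .I[.N], by=group]
idx$V1
# [1] 4 7 9
dat[idx$V1]   # same result as the one-liner above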

The dplyr translation of that would be

library(dplyr)
dat %>%
  group_by(group) %>%
  slice(if(any(x > 1.5)) which.max(x > 1.5) else n())

Or, more efficiently, computing the logical vector x > 1.5 only once per group

dat[, .SD[{temp = x > 1.5; if (any(temp)) which.max(temp) else .N}], by = group]

Thanks to @IceCreamTouCan, @sindri_baldur and @jangorecki for their valuable suggestions to improve this answer.

You could subset both ways (which are optimized by GForce) and then combine them:

D1 = dat[x>1.5, lapply(.SD, first), by=group]
D2 = dat[, lapply(.SD, last), by=group]
rbind(D1, D2[!D1, on=.(group)])

   group id         x
1:     2  7 1.5115220
2:     3  9 2.0184237
3:     1  4 0.6328626

There is some inefficiency here since we group by group three times. I'm not sure whether that is outweighed by the more efficient GForce computations in j. @jangorecki points out that the cost of grouping three times might be mitigated by setting the key first.
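
For example (a sketch; keying by group is my reading of that suggestion), the key is set once by reference and the three grouped calls can then make use of it:

setkey(dat, group)                      # sorts dat by group, by reference
D1 = dat[x>1.5, lapply(.SD, first), by=group]
D2 = dat[, lapply(.SD, last), by=group]
rbind(D1, D2[!D1, on=.(group)])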

Comment: I used lapply(.SD, last) since .SD[.N] is not yet optimized and last(.SD) throws an error. I changed the OP's .SD[1] to lapply(.SD, first) for the sake of symmetry.

Another option is:

dat[x>1.5 | group!=shift(group, -1L), .SD[1L], .(group)]
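
For reference, you can inspect which rows that filter keeps before .SD[1L] takes the first one per group (note this relies on the rows already being ordered by group, as they are here):

dat[x>1.5 | group!=shift(group, -1L)]
# keeps ids 4, 7, 8, 9 and 12: the rows above the threshold plus the last row of
# each group (the table's final row compares against NA in shift, but here it
# already passes x>1.5)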
