
Using data.table for a conditional sum within groups with lapply

I have a data.table where each row is an event with a start date and an end date; the number of days between start and end varies. For each event, I am trying to count how many other events have already ended by the time it begins. I can do this with lapply, but when I try it in data.table with by, I don't get the expected output. Sample code below:

library(data.table)

DT <- data.table(
  start = as.Date(c("2018-07-01","2018-07-03","2018-07-06","2018-07-08","2018-07-12","2018-07-15")),
  end = as.Date(c("2018-07-10","2018-07-04","2018-07-09","2018-07-20","2018-07-14","2018-07-27")),
  group_id = c("a", "a", "a", "b", "b", "b"))

# This produces the expected output (0,0,1,1,3,4):
lapply(DT$start, function(x) sum(x > DT$end))

# This also works using data.table:
DT[, count := lapply(DT$start, function(x) sum(x > DT$end))]

# However, I don't get the expected output (0,0,1,0,0,1) when I attempt to do this by group_id
DT[, count_by_group := lapply(DT$start, function(x) sum(x > DT$end)), by = group_id]

This gives the following output, where count_by_group is not the expected result:

        start        end group_id count count_by_group
1: 2018-07-01 2018-07-10        a     0              0
2: 2018-07-03 2018-07-04        a     0              0
3: 2018-07-06 2018-07-09        a     1              0
4: 2018-07-08 2018-07-20        b     1              0
5: 2018-07-12 2018-07-14        b     3              0
6: 2018-07-15 2018-07-27        b     4              0

Can someone help me understand how by changes the behavior? I've also tried to use different versions of the .SD feature, but wasn't able to get that to work either.

unlist()

unlist() works as well:

DT[, count_by_group := unlist(lapply(start, function(x) sum(x > end))), by = group_id]

Non-equi join

Alternatively, this can be solved by aggregating in a non-equi self join:

DT[, count_by_group := DT[DT, on = .(group_id, end < start), .N, by = .EACHI]$N]
DT
        start        end group_id count_by_group
1: 2018-07-01 2018-07-10        a              0
2: 2018-07-03 2018-07-04        a              0
3: 2018-07-06 2018-07-09        a              1
4: 2018-07-08 2018-07-20        b              0
5: 2018-07-12 2018-07-14        b              0
6: 2018-07-15 2018-07-27        b              1
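
As a quick consistency check, the sketch below rebuilds the sample DT from the question and compares the per-group counts from a vapply() approach with those from the non-equi join. The names by_vapply and by_join are only illustrative and do not appear in the original posts.

```r
library(data.table)

# Sample data from the question
DT <- data.table(
  start = as.Date(c("2018-07-01","2018-07-03","2018-07-06","2018-07-08","2018-07-12","2018-07-15")),
  end = as.Date(c("2018-07-10","2018-07-04","2018-07-09","2018-07-20","2018-07-14","2018-07-27")),
  group_id = c("a", "a", "a", "b", "b", "b"))

# Per-group counts computed row by row within each group
by_vapply <- DT[, vapply(start, function(x) sum(x > end), integer(1)), by = group_id]$V1

# Per-group counts from the non-equi self join:
# for each row of the inner DT (i), count rows in the same group with end < start
by_join <- DT[DT, on = .(group_id, end < start), .N, by = .EACHI]$N

# Both approaches should agree on the sample data
stopifnot(all(by_vapply == by_join))
```

With by = .EACHI, .N is evaluated once per row of the table in i, which is why the result lines up with the rows of DT and can be assigned directly as a column.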

Benchmark

The non-equi join is also the fastest method for cases with more than a few hundred rows:

library(bench)
bm <- press(
  n_grp = c(2L, 5L, 10L),
  n_row = 10^(2:4),
  {
    set.seed(1L)
    DT = data.table(
      group_id = sample.int(n_grp, n_row, TRUE),
      start = as.Date("2018-07-01") + rpois(n_row, 20L))
    DT[, end := start + rpois(n_row, 10L)]
    setorder(DT, group_id, start, end)
    mark(
      unlist = copy(DT)[, count_by_group := unlist(lapply(start, function(x) sum(x > end))), by = group_id],
      sapply = copy(DT)[, count_by_group := sapply(start, function(x) sum(x > end)), by = group_id],
      vapply = copy(DT)[, count_by_group := vapply(start, function(x) sum(x > end), integer(1)), by = group_id],
      nej = copy(DT)[, count_by_group := DT[DT, on = .(group_id, end < start), .N, by = .EACHI]$N]
    )
  }
)
ggplot2::autoplot(bm)


For 10000 rows, the non-equi join is about 10 times faster than the other methods.

As DT is being updated, copy() is used to create a fresh, unmodified copy of DT for each benchmark run.
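
A minimal sketch of why copy() matters here, using a small made-up table (the names alias and fresh are illustrative): := updates by reference, so plain assignment with <- does not protect the original table.

```r
library(data.table)

DT <- data.table(x = 1:3)

# Plain assignment does not copy: alias and DT refer to the same table
alias <- DT
alias[, y := x * 2L]   # this also adds y to DT

# copy() creates an independent table
fresh <- copy(DT)
fresh[, z := 0L]       # DT is not affected

names(DT)     # "x" "y"
names(fresh)  # "x" "y" "z"
```

Without copy(), the first benchmark expression would add count_by_group to DT itself, and the later expressions would no longer start from the same unmodified input.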

DT[, count_by_group := vapply(start, function(x) sum(x > end), integer(1)), by = group_id]

To refer to start and end by group, we need to leave the DT$ prefix out.
We use vapply() rather than lapply() because, if the right-hand side of := is a list, it is interpreted as a list of columns (and since only one column is expected, only the first element, a 0, is taken into account and recycled).
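
Both points can be checked on the sample data from the question. The following is a sketch; the names sizes, as_list and as_vec are illustrative only.

```r
library(data.table)

# Sample data from the question
DT <- data.table(
  start = as.Date(c("2018-07-01","2018-07-03","2018-07-06","2018-07-08","2018-07-12","2018-07-15")),
  end = as.Date(c("2018-07-10","2018-07-04","2018-07-09","2018-07-20","2018-07-14","2018-07-27")),
  group_id = c("a", "a", "a", "b", "b", "b"))

# Inside j with by, a bare column name is the current group's subset,
# while DT$start is always the full column, regardless of by.
sizes <- DT[, .(n_group = length(start), n_all = length(DT$start)), by = group_id]
print(sizes)  # n_group is 3 for each group, n_all is 6

# lapply() returns a list, vapply() returns an atomic integer vector
as_list <- lapply(DT$start, function(x) sum(x > DT$end))
as_vec <- vapply(DT$start, function(x) sum(x > DT$end), integer(1))
is.list(as_list)    # TRUE
is.integer(as_vec)  # TRUE

# Grouped version with bare column names and an atomic result
DT[, count_by_group := vapply(start, function(x) sum(x > end), integer(1)), by = group_id]
DT$count_by_group   # 0 0 1 0 0 1
```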
