简体   繁体   中英

Speedup query for R data.table - can this two-argument function be applied by group more quickly?

Is it possible to use data.table to apply a two-parameter function quickly by group across a data set? On a 1 million row data set, I am finding that calling the simple function defined below is taking over 11 seconds, which is much longer than I would expect for something of this complexity.

The self-contained code below outlines the essentials of what I am trying to do:

# generate data frame - 1 million rows
library(data.table)
# Reproducible test data: two integer grouping columns sampled from 1..1000
# with replacement, plus two standard-normal value columns.
# `replace` is spelled out in full rather than relying on partial argument
# matching (`repl`).
set.seed(42)
nn <- 1e6
daf <- data.frame(
  aa = sample(1:1000, nn, replace = TRUE),
  bb = sample(1:1000, nn, replace = TRUE),
  xx = rnorm(nn),
  yy = rnorm(nn),
  stringsAsFactors = FALSE
)

# myfunc is the function to apply to each group.
#   xx, yy: numeric vectors holding one group's values.
# Returns a single number: the plain mean of xx when the largest yy exceeds 1,
# otherwise the weighted mean of yy, where rows with positive xx get weight 2
# and all other rows get weight 1.
myfunc <- function(xx, yy) {
  if (max(yy) > 1) {
    return(mean(xx))
  }
  weights <- ifelse(xx > 0, 2, 1)
  weighted.mean(yy, weights)
}

# Time the grouped call: key the table on (aa, bb), then evaluate myfunc once
# per group. Running this takes around 11.5 seconds.
system.time({
  dt <- data.table(daf, key = c("aa", "bb"))
  dt <- dt[, myfunc(xx, yy), by = c("aa", "bb")]
})

# First rows of the result; the unnamed j expression ends up in column V1
head(dt)
# OUTPUT:
#    aa bb          V1
# 1:  1  2 -1.02605645
# 2:  1  3 -0.49318243
# 3:  1  4  0.02165797
# 4:  1  5  0.40811793
# 5:  1  6 -1.00312393
# 6:  1  7  0.14754417

Is there a way to significantly reduce the time for a function call like this?

I am interested in whether there is a more efficient way to perform the above calculation without completely re-writing the function call, or whether it can only be sped up by breaking apart the function and somehow rewriting it in data.table syntax.

Many thanks in advance for your replies.

Your original code, timed on my machine for reference:

# Baseline: the original per-group call to myfunc, timed as-is
system.time({
  dt <- data.table(daf, key = c("aa", "bb"))
  dt <- dt[, myfunc(xx, yy), by = c("aa", "bb")]
}) # 21.25
# Keep a copy of the baseline result so later variants can be checked
# against it with all.equal()
dtInitial <- copy(dt)

V1: if NA values do not concern you, you can modify your function like this:

# Variant of myfunc that replaces weighted.mean() with a hand-written
# weighted average, skipping weighted.mean()'s argument handling.
#   xx, yy: numeric vectors holding one group's values.
# Returns a single number: mean(xx) when max(yy) > 1, otherwise
# sum(yy * w) / sum(w) with w = 2 where xx > 0 and w = 1 elsewhere.
# There is no NA handling: an NA in xx propagates to the result.
myfunc2 <- function(xx, yy) {
  if (max(yy) > 1) {
    return(mean(xx))
  }
  w <- ifelse(xx > 0, 2, 1)
  # w is always 1 or 2, so no zero-weight entries need to be dropped before
  # summing (weighted.mean() does that only to guard against zero weights)
  sum(yy * w) / sum(w)
}

# Same grouped call as the baseline, but using myfunc2
system.time({
  dt <- data.table(daf, key = c("aa", "bb"))
  dtM <- dt[, myfunc2(xx, yy), by = c("aa", "bb")]
}) # 6.69
# Confirm the result matches the baseline
all.equal(dtM, dtInitial)
# [1] TRUE

V2: Also, you can do it faster like this:

# V2: compute each ingredient as a column (per-group aggregates are repeated
# on every row of the group), pick the final value row-wise, then collapse
# the repeated rows to one row per (aa, bb) group.
system.time({
  dt3 <- data.table(daf, key = c("aa", "bb"))
  dt3[, maxy := max(yy), by = c("aa", "bb")]
  dt3[, meanx := mean(xx), by = c("aa", "bb")]
  dt3[, w := ifelse(xx > 0, 2, 1)]
  # w is always 1 or 2, so no zero-weight filter is needed before summing
  dt3[, wm2 := sum(yy * w) / sum(w), by = c("aa", "bb")]
  r2 <- dt3[, .(aa, bb, V1 = ifelse(maxy > 1, meanx, wm2))]
  r2 <- unique(r2)
}) # 2.09
all.equal(r2, dtInitial)
# [1] TRUE

20 sec vs 2 sec for me


Update:

Or a little bit faster:

# Precompute the row-level weight and weighted value once, aggregate to one
# row per (aa, bb) group in a single grouped call, then choose the final
# value per group and drop the helper columns.
system.time({
  dt3 <- data.table(daf, key = c("aa", "bb"))
  dt3[, w := ifelse(xx > 0, 2, 1)]
  dt3[, yyw := yy * w]
  # j is followed directly by `by`; no empty argument in between
  r2 <- dt3[, .(maxy = max(yy),
                meanx = mean(xx),
                wm2 = sum(yyw) / sum(w)),
            by = c("aa", "bb")]
  r2[, V1 := ifelse(maxy > 1, meanx, wm2)]
  r2[, c("maxy", "meanx", "wm2") := NULL]
}) # 1.51

all.equal(r2, dtInitial)
# [1] TRUE

Another solution

# Add a row-level weight (2 where xx > 0, otherwise 1), aggregate the three
# ingredients per (aa, bb) group, then select the final mean per group.
# The helper columns MaxY, Mean1 and Mean2 are kept in `result`.
system.time({
  dat <- data.table(daf, key = c("aa", "bb"))
  dat[, xweight := (xx > 0) * 1 + 1]
  result <- dat[,
    .(
      MaxY = max(yy),
      Mean1 = mean(xx),
      Mean2 = sum(yy * xweight) / sum(xweight)
    ),
    keyby = c("aa", "bb")
  ]
  result[, FinalMean := ifelse(MaxY > 1, Mean1, Mean2)]
})

   user  system elapsed 
  1.964   0.059   1.348

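I've found a way to gain a further speedup of 8x, which reduces the time down to around 0.2 seconds on my machine. See below. Rather than calculating sum(yyw)/sum(w) directly for each group, which is time-consuming, we instead calculate the quantities sum(yyw) and sum(w) for each group, and only afterwards perform the division. Magic! (The likely explanation: when every expression in j is a plain call such as sum(), mean() or max() on a column, data.table can apply its internal GForce optimization and compute the aggregates for all groups in compiled code; a compound expression like sum(yyw)/sum(w) prevents this, so j is evaluated in R once per group. Running the query with verbose = TRUE shows whether GForce was used.)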

# Fastest variant: every entry of the grouped j is a single aggregate call
# (max, mean, sum); the division and the choice between the two means are
# done afterwards on the one-row-per-group result.
system.time({
  dt <- data.table(daf, key = c("aa", "bb"))
  # Weight 1 everywhere, 2 where xx > 0, then the weighted value per row
  dt[, w := 1][xx > 0, w := 2][, yyw := yy * w]
  res <- dt[, .(maxy = max(yy), meanx = mean(xx),
                wm2num = sum(yyw), wm2den = sum(w)),
            by = c("aa", "bb")]
  # Default to the weighted mean; override with mean(xx) where max(yy) > 1
  res[, V1 := wm2num / wm2den][maxy > 1, V1 := meanx]
  res[, c("maxy", "meanx", "wm2num", "wm2den") := NULL]
}) # 0.19

all.equal(res, dtInitial)
# [1] TRUE

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM