
Efficient and fast application of a function to 3D arrays in R

I have a very large 3D array (say 100 x 100 x 10) over which I would like to apply a function for pairwise comparisons. I've tried a number of solutions using data.table, mapply, etc. I'm hoping, perhaps naively, for a bigger speedup, and am considering just doing this with C++/Rcpp. But before going down that road, I thought I'd check whether anyone knows of a more elegant / faster solution to this problem. Many thanks!

Example code in R. For this smaller-dimension version of what I want to apply this to, mapply() is a little faster than data.table:

library(data.table) # needed below for setDT()

m <- 20
n <- 10 # number of data points per row/col combination

R <- array(runif(n*m*m), dim=c(m,m,n)) # 3D array to apply function over
grid <- expand.grid(A = 1:m, B = 1:m, C = 1:m, D = 1:m) # all ordered pairs of (row, col) indices (used as args below)

# function to compute the correlation between two cells' series,
# e.g. between R[1, 2, ] and R[1, 10, ]
ss2 <- function(a, b, c, d) {
  cor(R[a, b, ], R[c, d, ])
}

#solution with data.table
dt <- setDT(grid) # convert from data.frame to data.table (by reference)
sol_1 <- dt[, ss2(A, B, C, D), by = seq_len(nrow(dt))]

#solution with mapply
sol_2 <- mapply(ss2, grid$A, grid$B, grid$C, grid$D)

I tried this with mapply() and data.table(). I've also tried a parallelized version of apply() (parApply, https://dept.stat.lsa.umich.edu/~jerrick/courses/stat701/notes/parallel.html).
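A sketch of that kind of parallel approach (the cluster size and export details are illustrative, not necessarily what I ran; since each task is a tiny cor() call, scheduling overhead can eat much of the benefit):

library(parallel)

cl <- makeCluster(detectCores() - 1)  # illustrative cluster size
clusterExport(cl, "R")                # workers need a copy of the array
sol_par <- parApply(cl, grid, 1, function(idx) cor(R[idx[1], idx[2], ], R[idx[3], idx[4], ]))
stopCluster(cl)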

UPDATE: cora from the Rfast package gives further performance improvements.

By reshaping the array, we can call cor directly on a matrix for a roughly 2,000x speedup:

library(data.table)
library(Rfast)

m <- 20
n <- 10 # number of data points per row/col combination

R <- array(runif(n*m*m), dim=c(m,m,n)) # 3D array to apply function over
grid <- expand.grid(A = 1:m, B = 1:m, C = 1:m, D = 1:m)
ss2 <- function(a, b, c, d) cor(R[a, b, ], R[c, d, ])
dt <- setDT(grid)

microbenchmark::microbenchmark(
  sol_1 = dt[, ss2(A, B, C, D), by = seq_len(nrow(dt))][[2]],
  sol_2 = mapply(ss2, grid$A, grid$B, grid$C, grid$D),
  sol_3 = c(cor(t(matrix(R, m*m, n)))),
  sol_4 = c(cora(t(matrix(R, m*m, n)))),
  check = "equal",
  times = 10
)
#> Unit: microseconds
#>   expr       min        lq       mean    median        uq       max neval
#>  sol_1 2101327.2 2135311.0 2186922.33 2178526.6 2247049.6 2301429.5    10
#>  sol_2 2255828.9 2266427.5 2306180.23 2287911.0 2321609.6 2471711.7    10
#>  sol_3    1203.8    1222.2    1244.75    1236.1    1243.9    1343.5    10
#>  sol_4     922.6     945.8     952.68     951.9     955.8     988.8    10
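To see why the reshape works: matrix(R, m*m, n) lays each series R[a, b, ] out as a row (column-major, so the row index is (b - 1) * m + a), and transposing gives an n x (m*m) matrix whose column-wise correlations are exactly the pairwise correlations we want. A quick sanity check (CC is just an illustrative name; the full equality is what check = "equal" in the benchmark already verifies):

CC <- cor(t(matrix(R, m*m, n)))  # (m*m) x (m*m) matrix of all pairwise correlations

# the correlation between R[a, b, ] and R[c, d, ] sits at
# row (b - 1) * m + a, column (d - 1) * m + c, e.g. for (1, 2) vs (1, 10):
all.equal(CC[(2 - 1) * m + 1, (10 - 1) * m + 1], cor(R[1, 2, ], R[1, 10, ]))
#> [1] TRUE

# flattening column-major matches the expand.grid order used by mapply()
all.equal(c(CC), mapply(ss2, grid$A, grid$B, grid$C, grid$D))
#> [1] TRUE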

Timing the full 100 x 100 x 10 array:

m <- 100L
n <- 10L
R <- array(runif(n*m*m), dim=c(m,m,n))

microbenchmark::microbenchmark(
  sol_3 = c(cor(t(matrix(R, m*m, n)))),
  sol_4 = c(cora(t(matrix(R, m*m, n)))),
  check = "equal",
  times = 10
)
#> Unit: milliseconds
#>   expr       min        lq     mean   median       uq      max neval
#>  sol_3 1293.0739 1298.4997 1466.546 1503.453 1513.746 1902.802    10
#>  sol_4  879.8659  892.2699 1058.064 1055.668 1143.767 1300.282    10

Note that filling by column and then transposing tends to be slightly faster here than filling the n x (m*m) matrix by row (see the comparison below). Also note that ss2 and grid are no longer needed.
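For completeness, the by-row alternative builds the same matrix directly, so you can compare the two yourself (a quick sketch; the benchmark labels are illustrative):

# same n x (m*m) matrix, built two ways
all.equal(t(matrix(R, m*m, n)), matrix(R, n, m*m, byrow = TRUE))
#> [1] TRUE

microbenchmark::microbenchmark(
  fill_by_col_then_t = cora(t(matrix(R, m*m, n))),
  fill_by_row        = cora(matrix(R, n, m*m, byrow = TRUE)),
  times = 10
)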
