I have a piece of code that finds which rows of a matrix boxes are equal to a given vector x. This code uses an apply-family function (sapply), and I wonder if it can be optimized further.
x = floor(runif(4)*10)/10
boxes = as.matrix(do.call(expand.grid, lapply(1:4, function(x) {
seq(0, 1 - 1/10, length = 10)
})))
# can the following line be optimized further?
result <- which(sapply(1:nrow(boxes),function(i){all(boxes[i,] == x)}))
I did not manage to get rid of the apply function myself, but maybe you have better ideas than I do :)
One option is:
which(colSums(t(boxes) == x) == ncol(boxes))
Vectors are recycled column-wise, so we need to transpose boxes before comparing it to x with ==. Then we keep the columns of the transposed matrix (i.e. the original rows) whose sum equals ncol(boxes), meaning every comparison in that row is TRUE.
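To see why the transpose matters, here is a minimal sketch on a hypothetical 3 x 2 matrix m and target v (both names are illustrative, not from the question). Comparing m directly with v recycles v down the columns, so nothing lines up, while t(m) puts each original row in its own column next to a full copy of v:

```r
# Hypothetical 3 x 2 matrix, filled column-wise: rows are (1,4), (2,5), (3,6)
m <- matrix(c(1, 2, 3,
              4, 5, 6), nrow = 3)
v <- c(2, 5)   # the row we are looking for (row 2)

# Without transposing, v is recycled down the columns as 2, 5, 2 | 5, 2, 5,
# so every cell is compared with the wrong element of v (all FALSE here)
m == v

# Each column of t(m) is one original row, and v is recycled down that column,
# so element k of the row is compared with v[k]
t(m) == v

# Columns of t(m) where every comparison is TRUE are the matching rows of m
which(colSums(t(m) == v) == ncol(m))
```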
Here's a benchmark for this (possibly not representative) example
Irnv <- function() which(sapply(1:nrow(boxes),function(i){all(boxes[i,] == x)}))
ICT <- function() which(colSums(t(boxes) == x) == ncol(boxes))
RS <- function() which(rowSums(mapply(function(i, j) boxes[, i] == j, seq_len(ncol(boxes)), x)) == length(x))
RS2 <- function(){
boxes <- data.frame(boxes)
which(rowSums(mapply(`==`, boxes, x)) == length(x))
}
akrun <- function() which(rowSums((boxes == x[col(boxes)])) == ncol(boxes))
library(microbenchmark)
microbenchmark(Irnv(), ICT(), RS(), RS2(), akrun())
# Unit: microseconds
#     expr       min         lq       mean     median         uq       max neval
#   Irnv() 19218.470 20122.2645 24182.2337 21882.8815 24949.1385 66387.719   100
#    ICT()   300.308   323.2830   466.0395   342.3595   430.1545  7878.978   100
#     RS()   566.564   586.2565   742.4252   617.2315   688.2060  8420.927   100
#    RS2()   698.257   772.3090  1017.0427   842.2570   988.9240  9015.799   100
#  akrun()   442.667   453.9490   579.9102   473.6415   534.5645  6870.156   100
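Before comparing timings, it is worth checking that the approaches agree. A minimal sketch, assuming the question's setup (the seed 3251 is borrowed from the data section of a later answer, and the approaches list is just a hypothetical helper): it reruns several of the benchmarked functions and checks that each one returns the same row indices as the original sapply loop.

```r
set.seed(3251)
x <- floor(runif(4) * 10) / 10
boxes <- as.matrix(do.call(expand.grid, lapply(1:4, function(i) {
  seq(0, 1 - 1/10, length = 10)
})))

# Hypothetical helper: the benchmarked approaches collected in a named list
approaches <- list(
  Irnv  = function() which(sapply(1:nrow(boxes), function(i) all(boxes[i, ] == x))),
  ICT   = function() which(colSums(t(boxes) == x) == ncol(boxes)),
  RS    = function() which(rowSums(mapply(function(i, j) boxes[, i] == j,
                                          seq_len(ncol(boxes)), x)) == length(x)),
  akrun = function() which(rowSums(boxes == x[col(boxes)]) == ncol(boxes))
)

# Run each approach and drop any names so only the indices are compared
results <- lapply(approaches, function(f) unname(f()))

# TRUE when every approach returns exactly the same indices as the sapply loop
all_same <- all(vapply(results, identical, logical(1), results$Irnv))
all_same
```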
which(sapply(1:nrow(boxes),function(i){all(boxes[i,] == x)}))
#[1] 5805
A variation on your answer using mapply:
which(rowSums(mapply(function(i, j) boxes[, i] == j, seq_len(ncol(boxes)), x)) == length(x))
#[1] 5805
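A small sketch of what the mapply call builds, using a hypothetical 3 x 2 matrix m and target v (illustrative names only): one logical column per column of the matrix, which rowSums then counts across each row.

```r
# Hypothetical 3 x 2 example; rows are (1,4), (2,5), (3,6)
m <- matrix(c(1, 2, 3,
              4, 5, 6), nrow = 3)
v <- c(2, 5)

# mapply calls the function once per column index i, paired with v[i];
# each call returns a length-3 logical vector, and mapply simplifies the
# results into a 3 x 2 logical matrix (one column per column of m)
hits <- mapply(function(i, j) m[, i] == j, seq_len(ncol(m)), v)
hits

# A row matches when all its entries are TRUE, i.e. its row sum equals length(v)
which(rowSums(hits) == length(v))
```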
If boxes is allowed to be a data frame, we can simplify the above version (this only saves keystrokes, it is not faster; see ICT's benchmarks):
boxes <- data.frame(boxes)
which(rowSums(mapply(`==`, boxes, x)) == length(x))
#[1] 5805
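Because a data frame is a list of columns, mapply (and its wrapper Map) walks over it column by column. As a further variation, not one of the benchmarked answers and with no timing claimed, the per-column logical vectors can be combined with Reduce and & instead of being counted with rowSums. A minimal sketch on a hypothetical data frame df:

```r
# Hypothetical small data frame; each row is one candidate box
df <- data.frame(a = c(1, 2, 3), b = c(4, 5, 6))
v <- c(2, 5)

# Map pairs column k of df with v[k] and returns a list of logical vectors
per_column <- Map(`==`, df, v)

# Combine the per-column tests with & rather than summing TRUE values
which(Reduce(`&`, per_column))
```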
Benchmarks of the various answers on my system, in a fresh R session:
Irnv <- function() which(sapply(1:nrow(boxes),function(i){all(boxes[i,] == x)}))
ICT <- function() which(colSums(t(boxes) == x) == ncol(boxes))
RS <- function() which(rowSums(mapply(function(i, j) boxes[, i] == j, seq_len(ncol(boxes)), x)) == length(x))
RS2 <- function(){
boxes <- data.frame(boxes)
which(rowSums(mapply(`==`, boxes, x)) == length(x))
}
akrun <- function() which(rowSums((boxes == x[col(boxes)])) == ncol(boxes))
akrun2 <- function() which(rowSums(boxes == rep(x, each = nrow(boxes))) == ncol(boxes))
akrun3 <- function() which(rowSums(sweep(boxes, 2, x, `==`)) == ncol(boxes))
library(microbenchmark)
microbenchmark(Irnv(), ICT(), RS(), RS2(), akrun(), akrun2(), akrun3())
# Unit: microseconds
#     expr       min         lq       mean     median        uq       max neval
#   Irnv() 16335.205 16720.8905 18545.0979 17640.7665 18691.234 49036.793   100
#    ICT()   195.068   215.4225   444.9047   233.8600   329.288  4635.817   100
#     RS()   527.587   577.1160  1344.3033   639.7180  1373.426 36581.216   100
#    RS2()   648.996   737.6870  1810.3805   847.9865  1580.952 35263.632   100
#  akrun()   384.498   402.1985   761.0542   421.5025  1176.129  4102.214   100
# akrun2()   840.324   853.9825  1415.9330   883.3730  1017.014 34662.084   100
# akrun3()   399.645   459.7685  1186.7605   488.3345  1215.601 38098.927   100
Data:
set.seed(3251)
x = floor(runif(4)*10)/10
boxes = as.matrix(do.call(expand.grid, lapply(1:4, function(x) {
seq(0, 1 - 1/10, length = 10)
})))
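All of the approaches here test exact equality of doubles with ==. That works as long as the grid values and x round to identical doubles, but numbers that print the same can differ in the last bit (for example, 3 * 0.1 == 0.3 is FALSE in R; see R FAQ 7.31). If that is a concern, a tolerance-based comparison is one option. A minimal sketch, where match_rows_tol and its default tol = 1e-8 are hypothetical choices:

```r
# Doubles that print the same can differ in the last bit
3 * 0.1 == 0.3                      # FALSE
isTRUE(all.equal(3 * 0.1, 0.3))     # TRUE

# Hypothetical helper: a row matches when every entry is within tol of its target
match_rows_tol <- function(boxes, x, tol = 1e-8) {
  targets <- rep(x, each = nrow(boxes))   # one target per cell, column-major order
  close <- abs(boxes - targets) < tol     # logical matrix, same shape as boxes
  which(rowSums(close) == ncol(boxes))
}

# Tiny example where the first row is (0.1, 3 * 0.1)
m <- rbind(c(0.1, 3 * 0.1),
           c(0.2, 0.4))
v <- c(0.1, 0.3)
which(rowSums(m == rep(v, each = nrow(m))) == ncol(m))   # integer(0): exact test misses row 1
match_rows_tol(m, v)                                     # returns 1
```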
We can also use rowSums on a replicated x, so that its length matches the number of elements in boxes:
which(rowSums((boxes == x[col(boxes)])) == ncol(boxes))
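To illustrate what x[col(boxes)] produces, here is a sketch on a hypothetical 3 x 2 matrix m and target v: col(m) gives each cell's column number, so indexing v with it yields one target value per cell, in the same column-major order as m.

```r
# Hypothetical 3 x 2 example; rows are (1,4), (2,5), (3,6)
m <- matrix(c(1, 2, 3,
              4, 5, 6), nrow = 3)
v <- c(2, 5)

# col(m) has the same shape as m and holds each cell's column number
col(m)

# Indexing the plain vector v by it gives a vector of length nrow(m) * ncol(m)
# in which every cell's entry is the target value for that cell's column
target <- v[col(m)]
target

# Now == compares each cell of m with its own target, element by element
which(rowSums(m == target) == ncol(m))
```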
Or use rep:
which(rowSums(boxes == rep(x, each = nrow(boxes))) == ncol(boxes))
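A quick check, on a hypothetical 3 x 2 example, that rep(x, each = nrow(boxes)) lays out the same values as x[col(boxes)]:

```r
# Hypothetical 3 x 2 example
m <- matrix(c(1, 2, 3,
              4, 5, 6), nrow = 3)
v <- c(2, 5)

# Repeat each target value once per row, matching m's column-major layout
expanded <- rep(v, each = nrow(m))
expanded

# Same vector as the col()-based indexing
identical(expanded, v[col(m)])
```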
Or with sweep and rowSums:
which(rowSums(sweep(boxes, 2, x, `==`)) == ncol(boxes))
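A small sketch of what sweep does here, on a hypothetical 3 x 2 matrix m and target v: with MARGIN = 2 it applies `==` between each column m[, k] and v[k], returning a logical matrix with the same shape as m.

```r
# Hypothetical 3 x 2 example; rows are (1,4), (2,5), (3,6)
m <- matrix(c(1, 2, 3,
              4, 5, 6), nrow = 3)
v <- c(2, 5)

# sweep expands v across the rows (one value per column) and applies `==`
hits <- sweep(m, 2, v, `==`)
hits

# Rows where every column matched
which(rowSums(hits) == ncol(m))
```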