
Which rows of a matrix are equal to a given vector

I have a piece of code that finds which rows of a matrix boxes are equal to a given vector x. It uses sapply; can it be optimized further?

x = floor(runif(4)*10)/10
boxes = as.matrix(do.call(expand.grid, lapply(1:4, function(x) {
  seq(0, 1 - 1/10, length = 10)
})))

# can the following line be optimised further?
result <- which(sapply(1:nrow(boxes),function(i){all(boxes[i,] == x)}))

I did not manage to get rid of the sapply call myself, but maybe you'll have better ideas than me :)

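One option is which(colSums(t(boxes) == x) == ncol(boxes)).

Vectors are recycled column-wise (R stores a matrix column by column), so we need to transpose boxes before comparing it to x with ==. Each column of t(boxes) is then one row of boxes, lined up element by element with x. We then pick the columns (transposed rows) whose sum equals ncol(boxes), i.e. whose values are all TRUE.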
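To see why the transpose matters, here is a small sketch on a hypothetical 3 x 3 stand-in for boxes (b and x below are made-up toy data, not from the question). Without t(), x is recycled down the columns of b, so each row gets compared against a shifted copy of x.

```r
# Toy stand-in for boxes: rows 1 and 3 equal x, row 2 does not.
b <- rbind(c(1, 2, 3),
           c(4, 5, 6),
           c(1, 2, 3))
x <- c(1, 2, 3)

# Without transposing, x runs down each column of b, so the
# comparison is misaligned and no row reaches 3 matches.
rowSums(b == x)                        # 1 0 1

# After t(), each column holds one original row, so x lines up element by element.
colSums(t(b) == x)                     # 3 0 3
which(colSums(t(b) == x) == ncol(b))   # 1 3
```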

Here is a benchmark on this (possibly unrepresentative) example:

Irnv <- function() which(sapply(1:nrow(boxes),function(i){all(boxes[i,] == x)}))
ICT <- function() which(colSums(t(boxes) == x) == ncol(boxes))
RS <- function() which(rowSums(mapply(function(i, j) boxes[, i] == j, seq_len(ncol(boxes)), x)) == length(x))
RS2 <- function(){ 
  boxes <- data.frame(boxes)
  which(rowSums(mapply(`==`, boxes, x)) == length(x))
}
akrun <- function() which(rowSums((boxes == x[col(boxes)])) == ncol(boxes))


library(microbenchmark)
microbenchmark(Irnv(), ICT(), RS(), RS2(), akrun())
# Unit: microseconds
#     expr       min         lq       mean     median         uq       max neval
#   Irnv() 19218.470 20122.2645 24182.2337 21882.8815 24949.1385 66387.719   100
#    ICT()   300.308   323.2830   466.0395   342.3595   430.1545  7878.978   100
#     RS()   566.564   586.2565   742.4252   617.2315   688.2060  8420.927   100
#    RS2()   698.257   772.3090  1017.0427   842.2570   988.9240  9015.799   100
#  akrun()   442.667   453.9490   579.9102   473.6415   534.5645  6870.156   100
which(sapply(1:nrow(boxes),function(i){all(boxes[i,] == x)}))
#[1] 5805

A variation on your answer, using mapply to compare each column of boxes with the matching element of x:

which(rowSums(mapply(function(i, j) boxes[, i] == j, seq_len(ncol(boxes)), x)) == length(x))
#[1] 5805
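A sketch of what the mapply call builds, again on hypothetical toy data (b and x are made up for illustration): call i compares column i with the single value x[i], and mapply simplifies the results into a logical matrix with one column per column of b.

```r
b <- rbind(c(1, 2, 3),
           c(4, 5, 6),
           c(1, 2, 3))
x <- c(1, 2, 3)

# One call per column: b[, i] == x[i] gives a length-3 logical vector,
# and mapply() simplifies the three vectors into a 3 x 3 matrix.
hits <- mapply(function(i, j) b[, i] == j, seq_len(ncol(b)), x)
hits
rowSums(hits)                       # matching columns per row: 3 0 3
which(rowSums(hits) == length(x))   # rows 1 and 3
```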

If boxes is allowed to be a data frame, the version above can be shortened. This only saves keystrokes: in ICT's benchmarks RS2() is slower than RS().

boxes <- data.frame(boxes)
which(rowSums(mapply(`==`, boxes, x)) == length(x))
#[1] 5805
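The shorter form works because a data frame is a list of columns, so mapply walks over the columns and the elements of x in parallel. A small sketch with a hypothetical data frame boxes_df (toy data, not from the question):

```r
# Toy data frame: rows 1 and 3 equal x.
boxes_df <- data.frame(a = c(1, 4, 1), b = c(2, 5, 2), c = c(3, 6, 3))
x <- c(1, 2, 3)

# `==` is called once per column: boxes_df$a == 1, boxes_df$b == 2, boxes_df$c == 3.
# The column names of boxes_df carry over to the columns of the result.
hits <- mapply(`==`, boxes_df, x)
hits
which(rowSums(hits) == length(x))   # rows 1 and 3
```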

Benchmarks of the various answers on my system, in a fresh R session:

Irnv <- function() which(sapply(1:nrow(boxes),function(i){all(boxes[i,] == x)}))
ICT <- function() which(colSums(t(boxes) == x) == ncol(boxes))
RS <- function() which(rowSums(mapply(function(i, j) boxes[, i] == j, seq_len(ncol(boxes)), x)) == length(x))
RS2 <- function() {
  boxes <- data.frame(boxes)
  which(rowSums(mapply(`==`, boxes, x)) == length(x))
}
akrun <- function() which(rowSums((boxes == x[col(boxes)])) == ncol(boxes))
akrun2 <- function() which(rowSums(boxes == rep(x, each = nrow(boxes))) == ncol(boxes))
akrun3 <- function() which(rowSums(sweep(boxes, 2, x, `==`)) == ncol(boxes))

library(microbenchmark)
microbenchmark(Irnv(), ICT(), RS(), RS2(), akrun(), akrun2(), akrun3())


# Unit: microseconds
#     expr       min         lq       mean     median        uq       max neval
#   Irnv() 16335.205 16720.8905 18545.0979 17640.7665 18691.234 49036.793   100
#    ICT()   195.068   215.4225   444.9047   233.8600   329.288  4635.817   100
#     RS()   527.587   577.1160  1344.3033   639.7180  1373.426 36581.216   100
#    RS2()   648.996   737.6870  1810.3805   847.9865  1580.952 35263.632   100
#  akrun()   384.498   402.1985   761.0542   421.5025  1176.129  4102.214   100
# akrun2()   840.324   853.9825  1415.9330   883.3730  1017.014 34662.084   100
# akrun3()   399.645   459.7685  1186.7605   488.3345  1215.601 38098.927   100
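Timings only mean something if every candidate returns the same rows, so a quick agreement check may be worth running first. This sketch rebuilds boxes as in the question. As an assumption for the check, it takes x from row 5805 of boxes (so an exact match is guaranteed), and it inlines the answers' functions, renaming RS2's local data frame to boxes_df.

```r
# Rebuild the grid from the question (10^4 rows, 4 columns).
boxes <- as.matrix(do.call(expand.grid, lapply(1:4, function(i) {
  seq(0, 1 - 1/10, length = 10)
})))
# Assumption: x is copied from an existing row, so it matches exactly.
x <- unname(boxes[5805, ])

Irnv   <- function() which(sapply(1:nrow(boxes), function(i) all(boxes[i, ] == x)))
ICT    <- function() which(colSums(t(boxes) == x) == ncol(boxes))
RS     <- function() which(rowSums(mapply(function(i, j) boxes[, i] == j,
                                          seq_len(ncol(boxes)), x)) == length(x))
RS2    <- function() {
  boxes_df <- data.frame(boxes)
  which(rowSums(mapply(`==`, boxes_df, x)) == length(x))
}
akrun  <- function() which(rowSums(boxes == x[col(boxes)]) == ncol(boxes))
akrun2 <- function() which(rowSums(boxes == rep(x, each = nrow(boxes))) == ncol(boxes))
akrun3 <- function() which(rowSums(sweep(boxes, 2, x, `==`)) == ncol(boxes))

candidates <- list(Irnv = Irnv, ICT = ICT, RS = RS, RS2 = RS2,
                   akrun = akrun, akrun2 = akrun2, akrun3 = akrun3)
# Drop any names so that only the row indices are compared.
results <- lapply(candidates, function(f) unname(f()))
# TRUE for every candidate that returns exactly what the original loop returns.
agree <- vapply(results, identical, logical(1), results$Irnv)
agree
```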

data

set.seed(3251)
x = floor(runif(4)*10)/10
boxes = as.matrix(do.call(expand.grid, lapply(1:4, function(x) {
              seq(0, 1 - 1/10, length = 10)
})))
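Every method in this thread compares doubles with exact ==. That works when x holds exactly the same doubles as a row of boxes, but two computations of the "same" decimal can round differently (see R FAQ 7.31), and then no row matches. One possible workaround is a tolerance-based test. rows_near below is a hypothetical helper, not part of any package, and the absolute tolerance sqrt(.Machine$double.eps) is an assumption you may need to tune for your data.

```r
# Two computations of "0.3" that print alike but differ in the last bit.
0.1 * 3 == 0.3                   # FALSE: 0.1 * 3 is 0.30000000000000004
isTRUE(all.equal(0.1 * 3, 0.3))  # TRUE: equal within all.equal()'s tolerance

# Hypothetical helper: indices of the rows of m whose entries all lie
# within an absolute tolerance tol of the matching entry of x.
rows_near <- function(m, x, tol = sqrt(.Machine$double.eps)) {
  # Same transpose trick as ICT's answer, with |difference| < tol instead of ==.
  which(colSums(abs(t(m) - x) < tol) == ncol(m))
}

m <- rbind(c(0.1 * 3, 0.5),
           c(0.2,     0.5))
target <- c(0.3, 0.5)

which(colSums(t(m) == target) == ncol(m))  # integer(0): the exact test misses row 1
rows_near(m, target)                       # 1
```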

We can also replicate x to the same length as boxes (x[col(boxes)] puts x[j] at every position of column j) and then use rowSums:

which(rowSums((boxes == x[col(boxes)])) == ncol(boxes))

Or expand x with rep:

which(rowSums(boxes == rep(x, each = nrow(boxes))) == ncol(boxes))
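A small sketch showing that the x[col(boxes)] and rep(x, each = nrow(boxes)) expansions produce the same vector. It uses hypothetical toy data b and x. Both lay x out in column-major order, so comparing it with the matrix lines up row by row.

```r
b <- rbind(c(1, 2, 3),
           c(4, 5, 6),
           c(1, 2, 3))
x <- c(1, 2, 3)

# col(b) holds the column number of every cell, so indexing x with it
# gives x[j] for every cell of column j, as a plain vector.
x_long <- x[col(b)]
x_long                                      # 1 1 1 2 2 2 3 3 3
identical(x_long, rep(x, each = nrow(b)))   # TRUE

which(rowSums(b == x_long) == ncol(b))      # 1 3
```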

Or use sweep and rowSums:

which(rowSums(sweep(boxes, 2, x, `==`)) == ncol(boxes))
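With MARGIN = 2, sweep treats x as one value per column: it expands x to the shape of the matrix and applies `==`, returning a logical matrix of the same dimensions. A sketch on hypothetical toy data b and x:

```r
b <- rbind(c(1, 2, 3),
           c(4, 5, 6),
           c(1, 2, 3))
x <- c(1, 2, 3)

# Column j of b is compared with x[j]; the result keeps the shape of b.
eq <- sweep(b, 2, x, `==`)
eq
which(rowSums(eq) == ncol(b))   # 1 3
```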

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 