[英]Efficient way to apply function to all pairs of row (or column)
給定一個矩陣(可以是一個很大的維),我如何有效地計算結果矩陣d
,對於某些給定函數,每個值定義為d[i,j] = fun(x[, c(i,j)])
fun
以下是一個例子
x = matrix(1:30, 5)
d = matrix(,ncol(x), ncol(x)) ## the output matrix
for(i in 1:ncol(x)) ## I use a for loop here, should find a more efficient way
for(j in 1:ncol(x))
d[i,j] = sum(apply(x[,c(i,j)], 1, min))
一個sapply
循環將略快
sapply(1:NCOL(x), function(i) sapply(1:NCOL(x), function(j){
sum(apply(x[, c(i, j)], 1, min))
}))
# [,1] [,2] [,3] [,4] [,5] [,6]
#[1,] 15 15 15 15 15 15
#[2,] 15 40 40 40 40 40
#[3,] 15 40 65 65 65 65
#[4,] 15 40 65 90 90 90
#[5,] 15 40 65 90 115 115
#[6,] 15 40 65 90 115 140
考慮一下工作地點。
您要檢查x
所有列對。 對於每一對,您將創建一個n x 2矩陣,並對其應用一些函數。 在許多情況下,例如所示的情況,一些工作將用於移動數據以挑選出那些列並創建那些新矩陣。 (循環開銷很小。)其余的工作將用於應用該函數。 R
提供了提高兩者速度的機會:
當僅讀取數據而不使用函數修改數據時, R
具有一些內置的自動優化功能,可使用指針而不是完整副本來引用它們。
某些函數在應用於簡單(一維)數組時會固有地進行矢量化,但在使用apply*
函數或通過循環調用時可能會變慢。
這些為我們提供了一些有關在提高陣列操作速度時在何處查找的指導。 詳細信息取決於fun
,因此讓我們考慮問題中的示例:它計算n x 2數組的每一行中的較小者,並對這些結果求和。 R
支持內置的矢量化(非常快)功能pmin
,以計算行最小值。 這建議了以下解決方案:
n <- 50
m <- 100
x <- matrix(runif(n*m), n)
system.time({
y <- matrix(NA_real_, NCOL(x), NCOL(x))
for (i in seq_len(NCOL(y)))
for (j in seq_len(NCOL(y)))
y2[i,j] <- sum(pmin(x[, i], x[, j]))
})
在最好的情況下,我們知道時序最終將在n
為線性,在m
二次。 這是該解決方案相對於該線程的另一個答案中推薦的sapply
方法所提供的加速效果的實證研究。
該研究是使用Microsoft R Open(3.5.1)在四個Xeon內核上進行的。 對於小m
,相對定時是不確定的,因為這種解決方案幾乎不需要可測量的時間。 請注意,顯示的值是倍數,而不是百分比:因此,例如,對於n = 400列,典型值為30+,這表示該解決方案花費的時間sapply
解決方案時間的sapply
。
模式很清楚: pmin
的向量化對於大量的行( n
)取得了很大的成就,而R
的基礎優化最初對少量的列( m
)(小於40左右)產生了很大的不同,但幾乎沒有代表較大的m
。
本課的內容是,您應該直接努力通過向量化來改善fun
的時機,而不必擔心循環開銷。
這是僅計算一半值的基本R解決方案。 這是因為兩個for
循環的編碼方式,結果矩陣是對稱的。
我已經定義了一個fun
的功能可以應用。
fun <- function(x, i, j) sum(apply(x[, c(i, j)], 1, min))
f1 <- function(x){
d = matrix(NA, ncol(x), ncol(x))
for(i in 1:ncol(x)){ ## I use a for loop here, should find a more efficient way
for(j in 1:ncol(x))
d[i, j] = fun(x, i, j)
}
d
}
f2 <- function(x){
d = matrix(NA, ncol(x), ncol(x))
for(i in 1:ncol(x)) {
for(j in i:ncol(x)) d[i, j] = fun(x, i, j)
}
d[lower.tri(d)] <- t(d)[lower.tri(t(d))]
d
}
library(microbenchmark)
n <- 1e3
x = matrix(1:n, 125)
mb <- microbenchmark(
f1 = f1(x),
f2 = f2(x)
)
mb
#Unit: milliseconds
# expr min lq mean median uq max neval cld
# f1 14.117403 14.365764 15.297683 14.633804 15.202872 22.57475 100 b
# f2 7.964885 8.113796 8.650553 8.252852 8.399395 17.33304 100 a
這是平均43%的時間增加。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.