簡體   English   中英

將功能應用於所有成對的行(或列)的有效方法

[英]Efficient way to apply function to all pairs of row (or column)

給定一個矩陣(可以是一個很大的維),我如何有效地計算結果矩陣d ,對於某些給定函數,每個值定義為d[i,j] = fun(x[, c(i,j)]) fun

以下是一個例子

x = matrix(1:30, 5)
d = matrix(,ncol(x), ncol(x)) ## the output matrix
for(i in 1:ncol(x)) ## I use a for loop here, should find a more efficient way
for(j in 1:ncol(x)) 
d[i,j] = sum(apply(x[,c(i,j)], 1, min))

一個sapply循環將略快

sapply(1:NCOL(x), function(i) sapply(1:NCOL(x), function(j){
    sum(apply(x[, c(i, j)], 1, min))
}))
#     [,1] [,2] [,3] [,4] [,5] [,6]
#[1,]   15   15   15   15   15   15
#[2,]   15   40   40   40   40   40
#[3,]   15   40   65   65   65   65
#[4,]   15   40   65   90   90   90
#[5,]   15   40   65   90  115  115
#[6,]   15   40   65   90  115  140

考慮一下工作地點。

您要檢查x所有列對。 對於每一對,您將創建一個n x 2矩陣,並對其應用一些函數。 在許多情況下,例如所示的情況,一些工作將用於移動數據以挑選出那些列並創建那些新矩陣。 (循環開銷很小。)其余的工作將用於應用該函數。 R提供了提高兩者速度的機會:

  1. 當僅讀取數據而不使用函數修改數據時, R具有一些內置的自動優化功能,可使用指針而不是完整副本來引用它們。

  2. 某些函數在應用於簡單(一維)數組時會固有地進行矢量化,但在使用apply*函數或通過循環調用時可能會變慢。

這些為我們提供了一些有關在提高陣列操作速度時在何處查找的指導。 詳細信息取決於fun ,因此讓我們考慮問題中的示例:它計算n x 2數組的每一行中的較小者,並對這些結果求和。 R支持內置的矢量化(非常快)功能pmin ,以計算行最小值。 這建議了以下解決方案:

n <- 50
m <- 100
x <- matrix(runif(n*m), n)
system.time({
  y <- matrix(NA_real_, NCOL(x), NCOL(x))
  for (i in seq_len(NCOL(y)))
    for (j in seq_len(NCOL(y)))
     y2[i,j] <- sum(pmin(x[, i], x[, j]))
})

在最好的情況下,我們知道時序最終將在n為線性,在m二次。 這是該解決方案相對於該線程的另一個答案中推薦的sapply方法所提供的加速效果的實證研究。

圖:作為m和n的函數的加速的光柵圖

該研究是使用Microsoft R Open(3.5.1)在四個Xeon內核上進行的。 對於小m ,相對定時是不確定的,因為這種解決方案幾乎不需要可測量的時間。 請注意,顯示的值是倍數,而不是百分比:因此,例如,對於n = 400列,典型值為30+,這表示該解決方案花費的時間sapply解決方案時間的sapply

模式很清楚: pmin的向量化對於大量的行( n )取得了很大的成就,而R的基礎優化最初對少量的列( m )(小於40左右)產生了很大的不同,但幾乎沒有代表較大的m

本課的內容是,您應該直接努力通過向量化來改善fun的時機,而不必擔心循環開銷。

這是僅計算一半值的基本R解決方案。 這是因為兩個for循環的編碼方式,結果矩陣是對稱的。
我已經定義了一個fun的功能可以應用。

fun <- function(x, i, j) sum(apply(x[, c(i, j)], 1, min))

f1 <- function(x){
  d = matrix(NA, ncol(x), ncol(x))
  for(i in 1:ncol(x)){ ## I use a for loop here, should find a more efficient way
    for(j in 1:ncol(x)) 
      d[i, j] = fun(x, i, j)
  }
  d
}

f2 <- function(x){
  d = matrix(NA, ncol(x), ncol(x))
  for(i in 1:ncol(x)) {
    for(j in i:ncol(x)) d[i, j] = fun(x, i, j)
  }
  d[lower.tri(d)] <- t(d)[lower.tri(t(d))]
  d
}


library(microbenchmark)

n <- 1e3
x = matrix(1:n, 125)

mb <- microbenchmark(
  f1 = f1(x),
  f2 = f2(x)
)
mb
#Unit: milliseconds
# expr       min        lq      mean    median        uq      max neval cld
#   f1 14.117403 14.365764 15.297683 14.633804 15.202872 22.57475   100   b
#   f2  7.964885  8.113796  8.650553  8.252852  8.399395 17.33304   100  a 

這是平均43%的時間增加。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM