簡體   English   中英

R 中的快速距離計算

[英]Fast distance calculation in R

我正在嘗試計算

1)歐幾里得距離,和

2) 馬氏距離

對於 r 中的一組矩陣。 我一直在這樣做:

v1 <- structure(c(0.508, 0.454, 0, 2.156, 0.468, 0.488, 0.682, 1, 1.832, 
            0.44, 0.928, 0.358, 1, 1.624, 0.484, 0.516, 0.378, 1, 1.512, 
            0.514, 0.492, 0.344, 0, 1.424, 0.508, 0.56, 0.36, 1, 1.384, 0.776, 
            1.888, 0.388, 0, 1.464, 0.952, 0.252, 0.498, 1, 1.484, 0.594, 
            0.256, 0.54, 2, 2.144, 0.402, 0.656, 2.202, 1, 1.696, 0.252), 
          .Dim = c(5L, 10L), 
          .Dimnames = list(NULL, c("KW_1", "KW_2", "KW_3", "KW_4", "KW_5", "KW_6", "KW_7", "KW_8", "KW_9", "KW_10")))

v2 <- structure(c(1.864, 1.864, 1.864, 1.864, 1.864, 1.6, 1.6, 1.6, 
            1.6, 1.6, 1.536, 1.536, 1.536, 1.536, 1.536, 1.384, 1.384, 1.384, 
            1.384, 1.384, 6.368, 6.368, 6.368, 6.368, 6.368, 2.792, 2.792, 
            2.792, 2.792, 2.792, 2.352, 2.352, 2.352, 2.352, 2.352, 2.624, 
            2.624, 2.624, 2.624, 2.624, 1.256, 1.256, 1.256, 1.256, 1.256, 
            1.224, 1.224, 1.224, 1.224, 1.224), 
          .Dim = c(5L, 10L), 
          .Dimnames = list(NULL, c("KW_1", "KW_2", "KW_3", "KW_4", "KW_5", "KW_6", "KW_7", "KW_8", "KW_9", "KW_10")))

L2 <- sqrt(rowSums((v1-v2)^2))  # Euclidean distance for each row

它提供:

[1] 7.132452 7.568359 7.536904 5.448696 7.163580

這是完美的:但我聽說您也可以使用以下形式計算歐幾里得/L2 距離:

在此處輸入圖像描述

我想以這種方式計算我的距離,因為馬氏距離就是這個和協方差矩陣。 看到這個

但是,我還沒有弄清楚如何在 r 中對此進行編碼。 我試過了:

sqrt(crossprod((t(v1)-t(v2))))

sqrt((v1-v2) %*% t(v1-v2))

但他們就是不給我我想要的。 建議?

注意 - 我希望將其作為單個操作來執行,而不是在任何類型的循環中。 它必須非常快,因為我要多次執行數百萬行。 也許這是不可能的。 我願意更改v1v2的格式。

您需要將公式分別應用於每一行,例如:

> sapply(1:nrow(v1), function(i) {
+     q = v1[i, ] - v2[i, ]
+     d = sqrt(t(q) %*% q)
+     d
+ })
[1] 7.132452 7.568359 7.536904 5.448696 7.163580

如果您需要更快的東西,您可以隨時在 C++ 中嘗試相同的操作(代碼改編自此處):

#include <Rcpp.h>

using namespace Rcpp;

double dist2(NumericVector x, NumericVector y){
    double d = sqrt( sum( pow(x - y, 2) ) );
    return d;
}

// [[Rcpp::export]]
NumericVector calc_l2 (NumericMatrix x, NumericMatrix y){
    int out_length = x.nrow();
    NumericVector out(out_length);

    for (int i = 0 ; i < out_length; i++){
        NumericVector v1 = x.row(i);
        NumericVector v2 = y.row(i);
        double d = dist2(v1, v2);
        out(i) = d;
    }
    return (out) ;
}

在 R 中運行:

library(Rcpp)

sourceCpp("calc_L2.cpp")
calc_l2(v1, v2)

如果您內聯 function 調用,Marius 的 Rcpp 代碼會快 10 倍左右,但它仍然與sqrt(rowSums((m1-m2)^2))一樣快:

library(Rcpp)

sourceCpp("r/calc_L2.cpp") # original by Marius

cppFunction('NumericVector calc_l2_inline(NumericMatrix x,NumericMatrix y){
  int nrow=x.nrow();
  NumericVector out(nrow);
  for(int i=0;i<nrow;i++)out(i)=sqrt(sum(pow(x.row(i)-y.row(i),2)));
  return(out);
}')

ncol=10
nrow=1e5
m1=matrix(runif(ncol*nrow),nrow)
m2=matrix(runif(ncol*nrow),nrow)

microbenchmark(times=100,
  rowSums={sqrt(rowSums((m1-m2)^2))},
  `Rfast::rowsums`={sqrt(Rfast::rowsums((m1-m2)^2))},
  Rcpp_original={calc_l2(m1,m2)},
  Rcpp_inlined={calc_l2_inline(m1,m2)},
  sapply_dotproduct={sapply(1:nrow(m1),function(i){q=m1[i,]-m2[i,];sqrt(q%*%q)})},
  sapply_regular={sapply(1:nrow(m1),function(i)sqrt(sum((m1[i,]-m2[i,])^2)))},
  for_loop={o=numeric(nrow);for(i in 1:nrow)o[i]=sqrt(sum((m1[i,]-m2[i,])^2))},
  mapply={r=row(m1);mapply(function(x,y)sqrt(sum((x-y)^2)),split(m1,r),split(m2,r))}
)

結果:

Unit: milliseconds
              expr        min         lq       mean     median         uq         max neval
           rowSums   4.295901   4.708260   5.761508   5.461944   6.496243   10.247036   100
    Rfast::rowsums   3.004092   3.327411   4.135796   3.451392   5.731450    6.877907   100
     Rcpp_original  37.777999  39.480606  51.307351  43.006943  61.729813  176.979826   100
      Rcpp_inlined   4.232740   4.283238   4.379944   4.332177   4.400327    5.462128   100
 sapply_dotproduct 473.272534 538.187874 615.304276 611.288368 669.466721  875.786952   100
    sapply_regular 197.353688 233.303991 275.858154 260.292042 302.541703  536.336035   100
          for_loop 130.624967 153.188579 195.365026 190.774655 219.141935  526.906898   100
            mapply 603.384269 662.399258 717.631411 695.372090 738.274394 1038.938268   100

這是計算向量v到矩陣m的每一行的距離的一種快速方法:

sqrt(rowSums(m^2)+sum(v^2)-2*m%*%as.matrix(v)[,1])

或者,如果您有一個具有m行的矩陣和另一個具有n行的矩陣,那么以下是一種快速計算矩陣中每個行組合之間的m × n距離矩陣的方法:(使用tcrossprod(m1,m2)而不是m1%*%t(m2)使代碼快了大約 1%):

sqrt(outer(rowSums(m1^2),rowSums(m2^2),"+")-2*tcrossprod(m1,m2))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM