[英]Speed up of the calculation of the sum the point-wise difference in R
假設我有兩個數據集。 第一個是:
t1<-sample(1:10,10,replace = T)
t2<-sample(1:10,10,replace = T)
t3<-sample(1:10,10,replace = T)
t4<-sample(11:20,10,replace = T)
t5<-sample(11:20,10,replace = T)
xtrain<-rbind(t1,t2,t3,t4,t5)
xtrain
[,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
t1 7 3 9 10 4 9 2 1 6 9
t2 5 1 1 6 5 3 10 2 6 3
t3 8 6 9 7 9 2 3 5 1 8
t4 16 18 14 17 19 20 15 15 20 19
t5 13 14 18 13 11 19 13 17 16 14
第二個是:
t6<-sample(1:10,10,replace = T)
t7<-sample(11:20,10,replace = T)
xtest<-rbind(t6,t7)
xtest
[,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
t6 1 5 8 2 10 2 3 4 8 5
t7 14 18 15 12 17 20 17 13 16 17
我想做的是計算每行xtest
和每行xtrain
之間的距離之和。 例如:
sum((7-1)^2+(3-5)^2+(9-8)^2+.....(9-5)^2)
sum((5-1)^2+(1-5)^2+(1-8)^2+.....(4-5)^2)
...
sum((14-13)^2+(18-14)^2+(15-18)^2+.....(17-14)^2)
我目前擁有的是使用兩個 for 循環(見下文),我認為它不能處理大型數據集:
sumPD<-function(vector1,vector2){
sumPD1<-sum((vector1-vector2)^2)
return(sumPD1)
}
loc<-matrix(NA,nrow=dim(xtrain)[1],ncol=dim(xtest)[1])
for(j in 1:dim(xtest)[1]){
for(i in 1:dim(xtrain)[1]){
loc[i,j]<-sumPD(xtrain[i,],xtest[j,])
}
}
我想就如何修改代碼以提高效率征求建議。 先感謝您! 希望有好的討論!
rdist
package 具有快速計算這些成對距離的功能:
rdist::cdist(xtrain, xtest)^2
Output:
[,1] [,2]
[1,] 65 1029
[2,] 94 1324
[3,] 165 1103
[4,] 1189 213
[5,] 1271 191
一種選擇是outer
f1 <- Vectorize(function(i, j) sumPD(xtrain[i,], xtest[j,]))
loc2 <- outer(seq_len(nrow(xtrain)), seq_len(nrow(xtest)), f1)
identical(loc, loc2)
#[1] TRUE
您可以轉置矩陣,使用向量差異和單個循環:
ftrain <- t(xtrain)
ftest <- t(xtest)
sapply(1:(dim(ftest)[2]),function(i){
colSums((ftrain - ftest[,i])^2)
})
[,1] [,2]
t1 103 1182
t2 125 1262
t3 130 1121
t4 1478 159
t5 1329 142
colSums
非常有效,但是如果您想要更快的速度,請查看那里
這里有兩種簡單的方法。
使用dist
- 將計算比需要更多的距離:
dists <- as.matrix(dist(rbind(xtrain, xtest))^2)
dists <- dists[rownames(xtrain), rownames(xtest)]
dists
t6 t7
t1 140 1179
t2 134 693
t3 119 974
t4 1028 91
t5 1085 44
使用適用於 X 矩陣和 y 向量的簡單自定義函數:
euclid <- function(X,y) colSums((X-y)^2)
dists <- mapply(euclid, list(t(xtrain)), split(xtest, row(xtest)))
dists
[,1] [,2]
t1 140 1179
t2 134 693
t3 119 974
t4 1028 91
t5 1085 44
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.