R dist: calculating distance of few to many

Say you have some participants and controls in an experiment, each evaluated on three characteristics, something like this:

part_A <- c(3, 5, 4)
part_B <- c(12, 15, 18)
part_C <- c(50, 40, 45)

ctrl_1 <- c(4, 5, 5)
ctrl_2 <- c(1, 0, 4)
ctrl_3 <- c(13, 16, 17)
ctrl_4 <- c(28, 30, 35)
ctrl_5 <- c(51, 43, 44)

For each participant, I want to find which control case is the closest match.

I could get this with the dist() function, but it would also spend a lot of time calculating the distances between controls, which are useless to me (and in the real data, there are 1000 times more control cases than participant cases).
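For reference, a sketch of the dist() route described above, assuming the cases are stacked one per row (part_mat and ctrl_mat are illustrative names, not from the question). dist() measures distances between rows and returns all n(n - 1)/2 pairs: here 8 cases give 28 distances, of which only the 3 x 5 = 15 participant-to-control ones are used. With far more controls than participants, almost all of the work goes to control-control pairs.

```r
# part_mat / ctrl_mat: the question's vectors stored one case per row
part_mat <- rbind(part_A = c(3, 5, 4), part_B = c(12, 15, 18), part_C = c(50, 40, 45))
ctrl_mat <- rbind(ctrl_1 = c(4, 5, 5), ctrl_2 = c(1, 0, 4), ctrl_3 = c(13, 16, 17),
                  ctrl_4 = c(28, 30, 35), ctrl_5 = c(51, 43, 44))

# dist() returns every pairwise row distance of the stacked matrix,
# control-control and participant-participant pairs included
d_all <- as.matrix(dist(rbind(part_mat, ctrl_mat)))

# Only the participant-by-control block is actually needed
d_pc <- d_all[rownames(part_mat), rownames(ctrl_mat)]
closest <- colnames(d_pc)[apply(d_pc, 1, which.min)]
closest
# [1] "ctrl_1" "ctrl_3" "ctrl_5"
```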

Is there a way to ask only for the distances from each element of one set to each element of the other? And something that works for very large data sets?

In the example above, the result I want is:

  Participant Closest_Ctrl
1      part_A       ctrl_1
2      part_B       ctrl_3
3      part_C       ctrl_5

Here is a solution that should be sufficiently fast for a not-too-large number of participants:

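# one column per control: a 3 x 5 matrix
ctrl <- do.call(cbind, mget(ls(pattern = "ctrl_\\d+")))

# named list of the participant vectors
dat <- mget(ls(pattern = "part_[[:upper:]]+"))

# x (length 3) is recycled down each control column; square the differences
# before summing, then pick the control with the smallest distance
res <- vapply(dat, function(x) colnames(ctrl)[which.min(sqrt(colSums((x - ctrl)^2)))],
              FUN.VALUE = character(1))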

stack(res)
#  values    ind
#1 ctrl_1 part_A
#2 ctrl_3 part_B
#3 ctrl_5 part_C
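In the solution above, x - ctrl works because x has length nrow(ctrl) and is recycled down each column of the matrix. The squaring has to happen inside colSums(): colSums(x - ctrl)^2 squares the sum of the signed differences, so positive and negative differences cancel. On the question's data that form happens to pick the same controls, but a small made-up case (values chosen only for illustration) shows how it can go wrong:

```r
# Toy data: one participant at the origin, two candidate controls as columns
x <- c(0, 0)
ctrl <- cbind(a = c(1, -1), b = c(0.5, 0.5))

# Wrong: sums the signed differences first, so -1 and +1 cancel for "a"
wrong <- colSums(x - ctrl)^2     # a = 0, b = 1
# Right: squares each difference before summing (squared Euclidean distance)
right <- colSums((x - ctrl)^2)   # a = 2, b = 0.5

names(which.min(wrong))  # "a"
names(which.min(right))  # "b"
```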

If this is too slow, I would write it in Rcpp.
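Before dropping to Rcpp, another base-R option (not taken from either answer here) is to get every participant-to-control distance from matrix algebra, using the expansion ||p - c||^2 = ||p||^2 + ||c||^2 - 2 * sum(p * c), so most of the work is a single tcrossprod() call. A minimal sketch, assuming cases are stored as named rows; part_mat and ctrl_mat are illustrative names. The full n_part x n_ctrl matrix must fit in memory, and the expansion can lose floating-point precision when feature values are large compared with the distances between points.

```r
# Participants and controls as rows of two matrices with the same features
part_mat <- rbind(part_A = c(3, 5, 4), part_B = c(12, 15, 18), part_C = c(50, 40, 45))
ctrl_mat <- rbind(ctrl_1 = c(4, 5, 5), ctrl_2 = c(1, 0, 4), ctrl_3 = c(13, 16, 17),
                  ctrl_4 = c(28, 30, 35), ctrl_5 = c(51, 43, 44))

# Squared distances: one row per participant, one column per control
sq_dist <- outer(rowSums(part_mat^2), rowSums(ctrl_mat^2), "+") -
  2 * tcrossprod(part_mat, ctrl_mat)

# Column of the smallest distance in each row (ties go to the first column)
idx <- max.col(-sq_dist, ties.method = "first")
result <- data.frame(Participant  = rownames(part_mat),
                     Closest_Ctrl = rownames(ctrl_mat)[idx])
result
```

On the question's data this reproduces the desired table (ctrl_1, ctrl_3, ctrl_5). Squared distances are compared directly because sqrt() does not change which control is smallest.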

Convert input to data frames

parts <- do.call(data.frame, mget(ls(pattern = "part_[A-C]")))
ctrl <- do.call(data.frame, mget(ls(pattern = "ctrl_[1-5]")))

Generate output

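# Euclidean distance between every participant column and every control column;
# Vectorize() lets outer() apply the two-vector function to each pair
dists <- outer(parts, ctrl, Vectorize(function(x, y) sqrt(sum((x - y)^2))))

# For each participant (row), max.col() on the negated distances returns the
# column of the smallest distance, i.e. the closest control
data.frame(Participant = names(parts),
           Closest_Ctrl = names(ctrl)[max.col(-dists)])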

#   Participant Closest_Ctrl
# 1      part_A       ctrl_1
# 2      part_B       ctrl_3
# 3      part_C       ctrl_5
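One detail of max.col(): its default ties.method = "random" breaks ties at random, so when two controls are equally close to a participant the reported match can differ between runs, whereas which.min() in the first answer always returns the first minimum. Passing ties.method = "first" makes max.col() deterministic. A small illustration with a made-up distance matrix:

```r
# A 2 x 3 distance matrix where row 1 has a tie between columns 1 and 3
d <- rbind(c(1, 5, 1),
           c(4, 2, 3))

# Default ties.method = "random": row 1 may give 1 or 3 on different runs
max.col(-d)
# "first" always returns the lowest tied column index
max.col(-d, ties.method = "first")   # 1 2
```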

Benchmark

# f1(): outer() approach; f2(): vapply() approach. Both read the global
# `parts` and `ctrl` data frames, so define them before calling.
f1 <- function() {
  dists <- outer(parts, ctrl, Vectorize(function(x, y) sqrt(sum((x - y)^2))))
  # generate output by calculating column with min value (max negative value)
  data.frame(Participant = names(parts),
             Closest_Ctrl = names(ctrl)[max.col(-dists)])
}

f2 <- function() {
  res <- vapply(parts, function(x) colnames(ctrl)[which.min(sqrt(colSums((x - ctrl)^2)))],
                FUN.VALUE = character(1))
  stack(res)
}

# scale up: 300 participants, 500 controls
parts <- do.call(data.frame, mget(ls(pattern = "part_[A-C]")))
ctrl <- do.call(data.frame, mget(ls(pattern = "ctrl_[1-5]")))
parts <- do.call(cbind, replicate(100, parts, simplify = FALSE))
ctrl <- do.call(cbind, replicate(100, ctrl, simplify = FALSE))

# check that both give the same matches (r2 has its columns in the other order)
r1 <- f1()
r2 <- f2()
all.equal(setNames(lapply(r1, as.factor), 1:2),
          setNames(lapply(r2[2:1], as.factor), 1:2))
# [1] TRUE

microbenchmark::microbenchmark(f1(), f2(), times = 5)
# Unit: milliseconds
#  expr        min         lq       mean     median         uq        max neval
#  f1()   305.7324   314.8356   435.3961   324.6116   461.4788   770.3221     5
#  f2() 12359.6995 12831.7995 13567.8296 13616.5216 14244.0836 14787.0438     5
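These f2() timings are measured on data frames: parts and ctrl were built with data.frame(), so x - ctrl dispatches to the data-frame arithmetic method on every call, whereas the first answer built ctrl as a matrix with cbind(). A sketch of a matrix-based variant for a fairer comparison, with the same 100x replication; f2_mat, part_mat and ctrl_mat are hypothetical names, and it uses system.time() instead of microbenchmark so it needs only base R. Timings will depend on the machine.

```r
# Example data as matrices: columns are cases, rows are the three features
part_mat <- cbind(part_A = c(3, 5, 4), part_B = c(12, 15, 18), part_C = c(50, 40, 45))
ctrl_mat <- cbind(ctrl_1 = c(4, 5, 5), ctrl_2 = c(1, 0, 4), ctrl_3 = c(13, 16, 17),
                  ctrl_4 = c(28, 30, 35), ctrl_5 = c(51, 43, 44))

# Same scale-up as the benchmark (300 participants, 500 controls), kept as matrices
part_big <- do.call(cbind, replicate(100, part_mat, simplify = FALSE))
ctrl_big <- do.call(cbind, replicate(100, ctrl_mat, simplify = FALSE))

# f2_mat: the vapply() approach with plain matrix arithmetic
f2_mat <- function(parts, ctrl) {
  # loop over participant column indices (vapply over a matrix would visit cells)
  res <- vapply(seq_len(ncol(parts)), function(j) {
    d2 <- colSums((parts[, j] - ctrl)^2)   # squared distance to every control
    colnames(ctrl)[which.min(d2)]
  }, FUN.VALUE = character(1))
  data.frame(Participant = colnames(parts), Closest_Ctrl = res)
}

res_big <- f2_mat(part_big, ctrl_big)
head(res_big, 3)
system.time(f2_mat(part_big, ctrl_big))
```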

Benchmark 2

# 30 participants, 50,000 controls; reuses f1() and f2() from the first benchmark
parts <- do.call(data.frame, mget(ls(pattern = "part_[A-C]")))
ctrl <- do.call(data.frame, mget(ls(pattern = "ctrl_[1-5]")))
parts <- do.call(cbind, replicate(10, parts, simplify = FALSE))
ctrl <- do.call(cbind, replicate(10 * 1000, ctrl, simplify = FALSE))

r1 <- f1()
r2 <- f2()

all.equal(setNames(lapply(r1, as.factor), 1:2),
          setNames(lapply(r2[2:1], as.factor), 1:2))
# [1] TRUE

microbenchmark::microbenchmark(f1(), f2(), times = 5)
# Unit: seconds
#  expr        min         lq       mean     median         uq        max neval
#  f1()   3.450176   4.211997   4.493805   4.339818   5.154191   5.312844     5
#  f2() 119.120484 124.280423 132.637003 130.858727 131.148630 157.776749     5
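The question also asks for something that scales. For control sets much larger than the 50,000 used here, holding the full participant-by-control distance matrix may become the limiting factor. A hedged sketch (not from either answer): the hypothetical helper nearest_ctrl() walks through the controls in blocks of chunk_size rows, computes squared distances for each block with the expansion ||p||^2 + ||c||^2 - 2 * sum(p * c), and keeps a running minimum per participant, so memory stays around n_part x chunk_size. It assumes both inputs are numeric matrices with one named case per row.

```r
# nearest_ctrl(): hypothetical helper; rows of part_mat / ctrl_mat are cases
nearest_ctrl <- function(part_mat, ctrl_mat, chunk_size = 10000) {
  p2 <- rowSums(part_mat^2)
  best_d <- rep(Inf, nrow(part_mat))          # smallest squared distance so far
  best_i <- rep(NA_integer_, nrow(part_mat))  # row index of that control
  for (s in seq(1, nrow(ctrl_mat), by = chunk_size)) {
    rows <- s:min(s + chunk_size - 1, nrow(ctrl_mat))
    blk <- ctrl_mat[rows, , drop = FALSE]
    # squared distances for this block: participants x block controls
    d2 <- outer(p2, rowSums(blk^2), "+") - 2 * tcrossprod(part_mat, blk)
    j <- max.col(-d2, ties.method = "first")
    cand <- d2[cbind(seq_len(nrow(d2)), j)]
    better <- cand < best_d                   # strict: earlier blocks win ties
    best_d[better] <- cand[better]
    best_i[better] <- rows[j[better]]
  }
  data.frame(Participant  = rownames(part_mat),
             Closest_Ctrl = rownames(ctrl_mat)[best_i],
             Distance     = sqrt(pmax(best_d, 0)))  # guard tiny negative round-off
}

part_mat <- rbind(part_A = c(3, 5, 4), part_B = c(12, 15, 18), part_C = c(50, 40, 45))
ctrl_mat <- rbind(ctrl_1 = c(4, 5, 5), ctrl_2 = c(1, 0, 4), ctrl_3 = c(13, 16, 17),
                  ctrl_4 = c(28, 30, 35), ctrl_5 = c(51, 43, 44))

# a tiny chunk_size forces several blocks on the example data
res <- nearest_ctrl(part_mat, ctrl_mat, chunk_size = 2)
res
```

A larger chunk_size means fewer R-level iterations but a bigger temporary matrix; the chunk size is the knob that trades speed for memory.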
