簡體   English   中英

R data.table 的連續行之間的快速余弦距離

[英]R fast cosine distance between consecutive rows of a data.table

如何有效地計算 data.table 的大(約 4m 行)的(幾乎)連續行之間的距離? 我已經概述了我目前的方法,但它非常緩慢。 我的實際數據最多有幾百列。 我需要計算滯后和領先以供將來使用,因此我創建了這些並使用它們來計算距離。

library(data.table)
library(proxy)

set_shift_col <- function(df, shift_dir, shift_num, data_cols, byvars = NULL){
  df[, (paste0(data_cols, "_", shift_dir, shift_num)) := shift(.SD, shift_num, fill = NA, type = shift_dir), byvars, .SDcols = data_cols]
}

set_shift_dist <- function(dt, shift_dir, shift_num, data_cols){
  stopifnot(shift_dir %in% c("lag", "lead"))
  shift_str <- paste0(shift_dir, shift_num)
  dt[, (paste0("dist", "_", shift_str)) := as.numeric(
    proxy::dist(
      rbindlist(list(
        .SD[,data_cols, with=FALSE], 
        .SD[, paste0(data_cols, "_" , shift_str), with=FALSE]
      ), use.names = FALSE), 
      method = "cosine")
  ), 1:nrow(dt)]
}

n <- 10000
test_data <- data.table(a = rnorm(n), b = rnorm(n), c = rnorm(n), d = rnorm(n))

cols <- c("a", "b", "c", "d")

set_shift_col(test_data, "lag", 1, cols)
set_shift_col(test_data, "lag", 2, cols)
set_shift_col(test_data, "lead", 1, cols)
set_shift_col(test_data, "lead", 2, cols)

set_shift_dist(test_data, "lag", 1, cols)

我確信這是一種非常低效的方法,任何建議將不勝感激!

您沒有使用proxy::dist function 中的矢量化效率 - 而不是為每一行調用一次,您可以通過一次調用獲得所需的所有距離。

試試這個替換 function 並比較速度:

set_shift_dist2 <- function(dt, shift_dir, shift_num, data_cols){
  stopifnot(shift_dir %in% c("lag", "lead"))
  shift_str <- paste0(shift_dir, shift_num)
  dt[, (paste0("dist2", "_", shift_str)) := proxy::dist(
    .SD[,data_cols, with=FALSE], 
    .SD[, paste0(data_cols, "_" , shift_str), with=FALSE], 
    method = "cosine", 
    pairwise = TRUE
  )]
}

您也可以在一個 go 中執行此操作,而無需在表中存儲數據副本

test_data[, dist_lag1 := proxy::dist(
  .SD, 
  as.data.table(shift(.SD, 1)), 
  pairwise = TRUE, 
  method = 'cosine'
  ), .SDcols = c('a', 'b', 'c', 'd')]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM