[英]R fast cosine distance between consecutive rows of a data.table
如何有效地計算 data.table 的大(約 4m 行)的(幾乎)連續行之間的距離? 我已經概述了我目前的方法,但它非常緩慢。 我的實際數據最多有幾百列。 我需要計算滯后和領先以供將來使用,因此我創建了這些並使用它們來計算距離。
library(data.table)
library(proxy)
set_shift_col <- function(df, shift_dir, shift_num, data_cols, byvars = NULL){
df[, (paste0(data_cols, "_", shift_dir, shift_num)) := shift(.SD, shift_num, fill = NA, type = shift_dir), byvars, .SDcols = data_cols]
}
set_shift_dist <- function(dt, shift_dir, shift_num, data_cols){
stopifnot(shift_dir %in% c("lag", "lead"))
shift_str <- paste0(shift_dir, shift_num)
dt[, (paste0("dist", "_", shift_str)) := as.numeric(
proxy::dist(
rbindlist(list(
.SD[,data_cols, with=FALSE],
.SD[, paste0(data_cols, "_" , shift_str), with=FALSE]
), use.names = FALSE),
method = "cosine")
), 1:nrow(dt)]
}
n <- 10000
test_data <- data.table(a = rnorm(n), b = rnorm(n), c = rnorm(n), d = rnorm(n))
cols <- c("a", "b", "c", "d")
set_shift_col(test_data, "lag", 1, cols)
set_shift_col(test_data, "lag", 2, cols)
set_shift_col(test_data, "lead", 1, cols)
set_shift_col(test_data, "lead", 2, cols)
set_shift_dist(test_data, "lag", 1, cols)
我確信這是一種非常低效的方法,任何建議將不勝感激!
您沒有使用proxy::dist
function 中的矢量化效率 - 而不是為每一行調用一次,您可以通過一次調用獲得所需的所有距離。
試試這個替換 function 並比較速度:
set_shift_dist2 <- function(dt, shift_dir, shift_num, data_cols){
stopifnot(shift_dir %in% c("lag", "lead"))
shift_str <- paste0(shift_dir, shift_num)
dt[, (paste0("dist2", "_", shift_str)) := proxy::dist(
.SD[,data_cols, with=FALSE],
.SD[, paste0(data_cols, "_" , shift_str), with=FALSE],
method = "cosine",
pairwise = TRUE
)]
}
您也可以在一個 go 中執行此操作,而無需在表中存儲數據副本
test_data[, dist_lag1 := proxy::dist(
.SD,
as.data.table(shift(.SD, 1)),
pairwise = TRUE,
method = 'cosine'
), .SDcols = c('a', 'b', 'c', 'd')]
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.