How to efficiently split each row into test and train subsets using R?

I have a data table that provides the length and composition of given vectors, for example:

library(data.table)
set.seed(1)

dt = data.table(length = c(100, 150),
                n_A = c(30, 30), 
                n_B = c(20, 100), 
                n_C = c(50, 20))

I need to randomly split each vector into two subsets with 80% and 20% of the observations respectively. I can currently do this using a for loop. For example:

dt_80_list <- list() # create output lists
dt_20_list <- list()

for (i in seq_len(nrow(dt))){ # for each row in the data.table
  
  sample_vec <- sample( c(   rep("A", dt$n_A[i]), # create a randomised vector with the given number of each component
                             rep("B", dt$n_B[i]),
                             rep("C", dt$n_C[i]) ) )
  
  sample_vec_80 <- sample_vec[1:floor(length(sample_vec)*0.8)] # subset 80% of the vector
  
  dt_80_list[[i]] <- data.table(   length = length(sample_vec_80), # count the number of each component in the subset and output to list
                         n_A = length(sample_vec_80[which(sample_vec_80 == "A")]),
                         n_B = length(sample_vec_80[which(sample_vec_80 == "B")]),
                         n_C = length(sample_vec_80[which(sample_vec_80 == "C")])
  )
  
  dt_20_list[[i]] <- data.table(   length = dt$length[i] - dt_80_list[[i]]$length, # subtract the number of each component in the 80% to identify the number in the 20%
                         n_A = dt$n_A[i] - dt_80_list[[i]]$n_A,
                         n_B = dt$n_B[i] - dt_80_list[[i]]$n_B,
                         n_C = dt$n_C[i] - dt_80_list[[i]]$n_C
  )
}
dt_80 <- do.call("rbind", dt_80_list) # collapse lists to output data.tables
dt_20 <- do.call("rbind", dt_20_list)

However, the dataset I need to apply this to is very large, and this is too slow. Does anyone have any suggestions for how I could improve performance?

Thanks.

(I assumed your dataset consists of many more rows, but only a few columns.)

Here's a version I came up with, with three main changes:

  • use .N and by= to count the number of "A", "B", "C" drawn in each row
  • use the size argument of sample
  • join the original dt and dt_80 to calculate dt_20 without a for-loop
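## draw training data
dt[, row := seq_len(.N)]   # row id used for grouping and for the join below
dt_80 <- dcast(
  dt[, .(draw = sample(c(rep("A80", n_A),
                         rep("B80", n_B),
                         rep("C80", n_C)),
                       size = floor(.8 * length))),  # 80% of the row, rounded down
     by = row
   ][, .N, by = .(row, draw)],                       # count draws per row and category
  row ~ draw, value.var = "N", fill = 0L             # 0 instead of NA if a category is never drawn in a row
)[, length80 := A80 + B80 + C80]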

## draw test data
dt_20 <- dt[dt_80,
            .(A20=n_A-A80,
              B20=n_B-B80,
              C20=n_C-C80),on="row"][,length20:=A20+B20+C20]
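
Since the test counts are obtained purely by subtraction, a quick consistency check can be reassuring before running this on the full dataset. The following is only a sketch: `check_split` is a hypothetical helper (not part of the code above), and it assumes the three tables have their rows in the same order and use the column names from this answer.

```r
library(data.table)

# Hypothetical helper: verifies that the 80% and 20% counts add back up to
# the original composition of every row, and that the 80% subset has the
# expected size. Assumes dt, dt_80 and dt_20 are aligned row by row.
check_split <- function(dt, dt_80, dt_20, frac = 0.8) {
  stopifnot(
    all(dt_80$A80 + dt_20$A20 == dt$n_A),
    all(dt_80$B80 + dt_20$B20 == dt$n_B),
    all(dt_80$C80 + dt_20$C20 == dt$n_C),
    all(dt_80$length80 + dt_20$length20 == dt$length),
    all(dt_80$length80 == floor(frac * dt$length))
  )
  invisible(TRUE)
}

# Hand-made example tables (illustrative values, not real output)
dt_ex    <- data.table(length = c(100, 150), n_A = c(30, 30),
                       n_B = c(20, 100), n_C = c(50, 20))
dt_80_ex <- data.table(row = 1:2, A80 = c(24, 25), B80 = c(16, 80),
                       C80 = c(40, 15), length80 = c(80, 120))
dt_20_ex <- data.table(A20 = c(6, 5), B20 = c(4, 20),
                       C20 = c(10, 5), length20 = c(20, 30))

ok <- check_split(dt_ex, dt_80_ex, dt_20_ex)
```

If any of the conditions fails, `stopifnot` raises an error naming the offending comparison.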

There is probably still room for optimization, but I hope it already helps :)
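
One possible direction for further optimization, sketched here as a different technique rather than a tweak of the code above: since only the counts per category are needed, the shuffled vector never has to be built. Drawing k items without replacement from a pool of A, B and C follows a multivariate hypergeometric distribution, which can be sampled with two calls to base R's `rhyper`, vectorised over all rows. This assumes that `length` equals `n_A + n_B + n_C` in every row, as in the example data.

```r
library(data.table)

set.seed(1)
dt <- data.table(length = c(100, 150),
                 n_A = c(30, 30),
                 n_B = c(20, 100),
                 n_C = c(50, 20))

# Size of the 80% subset for every row (same floor() rule as the question)
k <- floor(0.8 * dt$length)

# Draw the number of A's among the k items, then the number of B's among
# the remaining non-A items; C takes whatever is left.
# rhyper(nn, m, n, k) is vectorised over m, n and k.
A80 <- rhyper(nrow(dt), m = dt$n_A, n = dt$n_B + dt$n_C, k = k)
B80 <- rhyper(nrow(dt), m = dt$n_B, n = dt$n_C,          k = k - A80)
C80 <- k - A80 - B80

dt_80 <- data.table(length80 = k, A80 = A80, B80 = B80, C80 = C80)
dt_20 <- data.table(length20 = dt$length - k,
                    A20 = dt$n_A - A80,
                    B20 = dt$n_B - B80,
                    C20 = dt$n_C - C80)
```

Memory use is then proportional to the number of rows rather than to the total number of observations. With more than three categories the same idea would need one conditional draw per additional category.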

EDIT

Here is my first idea. I did not post it initially because the code above is much faster, but this one might be more memory-efficient, which seems crucial in your case. So even if you already have a working solution, this might be of interest...

library(data.table)
library(Rfast)

## add row numbers
dt[,row:=1:nrow(dt)]

## sampling function
sampfunc <- function(n_A,n_B,n_C){ 
  draw <- sample(c(rep("A80",n_A),
                   rep("B80",n_B),
                   rep("C80",n_C)),
                 size=.8*(n_A+n_B+n_C))
  out <- Rfast::Table(draw)
  return(as.list(out))
}

## draw training data
dt_80 <- dt[,sampfunc(n_A,n_B,n_C),by=row]
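
One thing to watch with a per-row function like this is that every call should return the same set of columns, even when a category is not drawn at all in some row. Below is a base-R sketch that avoids the question by tabulating a factor with fixed levels. `sampfunc_fixed` is a hypothetical variant, not the function above; it does not use Rfast, and I have not benchmarked it against `Rfast::Table`.

```r
library(data.table)

set.seed(1)
# Third row is an illustrative edge case with no B's at all
dt <- data.table(n_A = c(30, 30, 1),
                 n_B = c(20, 100, 0),
                 n_C = c(50, 20, 40))
dt[, row := seq_len(.N)]

# Hypothetical variant of sampfunc: tabulating a factor with fixed levels
# returns a count for every category, including zeros, so each row yields
# the same three columns.
sampfunc_fixed <- function(n_A, n_B, n_C) {
  lv <- c("A80", "B80", "C80")
  pool <- rep(lv, times = c(n_A, n_B, n_C))
  draw <- sample(pool, size = floor(0.8 * length(pool)))
  counts <- tabulate(factor(draw, levels = lv), nbins = length(lv))
  setNames(as.list(counts), lv)
}

## draw training data
dt_80 <- dt[, sampfunc_fixed(n_A, n_B, n_C), by = row]
```

The test counts can then be computed by joining `dt` and `dt_80` on `row` and subtracting, as in the first version.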
