繁体   English   中英

RCpp会加快基本R功能的评估吗?

[英]Will RCpp speed up the evaluation of basic R functions?

请原谅,但我对Rcpp了解不多,但我想弄清楚是否学习它是为了改善我正在写的包。

我已经编写了一个R封装(应该)有效地随机采用MCMC算法从高维,受约束的空间中采样。 它是(未完成的)位于https://github.com/davidkane9/kmatching

问题是,当我运行一个名为Gelman-Rubin诊断的统计测试时,看看我的MCMC算法是否已经收敛到静止分布,我应该得到R = 1的统计数据,但是我的数字非常高,这基本上告诉了我取样很糟糕,没有人应该使用它。 解决方案是采取更多样本并跳过更多(从1000个中取出1个而不是每100个中取1个)。 然而,这需要很多时间。 如果你想运行一些代码,这是一个例子:

##install the package first
data(lalonde)
matchvars = c("age", "educ", "black")
k = kmatch(x = lalonde, weight.var = "treat", match.var = matchvars, n = 1000, skiplength = 1000, chains = 2, verbose = TRUE)

看看这个Rprof输出我得到的是rnorm%*%占用了大部分时间:

                       total.time total.pct self.time self.pct
"kmatch"                  1453.14    100.00      0.00     0.00
"hitandrun"               1450.18     99.79    128.80     8.86
"%*%"                      757.00     52.09    757.00    52.09
"cat"                      343.18     23.62    329.82    22.70
"rnorm"                    106.34      7.32    103.50     7.12
"mirror"                    35.26      2.43     21.84     1.50
"paste"                     14.02      0.96     14.02     0.96
"stdout"                    13.36      0.92     13.36     0.92
"runif"                     13.32      0.92     13.32     0.92
"/"                         12.82      0.88     12.82     0.88
">"                          7.42      0.51      7.42     0.51
"<"                          6.22      0.43      6.22     0.43
"-"                          5.78      0.40      5.78     0.40
"max"                        5.18      0.36      5.18     0.36
"nchar"                      5.12      0.35      5.12     0.35
"*"                          4.84      0.33      4.84     0.33
"min"                        3.94      0.27      3.94     0.27
"sum"                        3.42      0.24      3.42     0.24
"gelman.diag"                2.90      0.20      0.00     0.00
"=="                         2.86      0.20      2.86     0.20
"ncol"                       2.84      0.20      2.84     0.20
"apply"                      2.72      0.19      0.26     0.02
"+"                          2.48      0.17      2.48     0.17
"FUN"                        2.32      0.16      1.66     0.11
"^"                          2.08      0.14      2.08     0.14
":"                          1.24      0.09      1.24     0.09
"sqrt"                       0.96      0.07      0.96     0.07
"%%"                         0.90      0.06      0.90     0.06
"mean.default"               0.62      0.04      0.62     0.04
"lapply"                     0.40      0.03      0.26     0.02
"("                          0.32      0.02      0.32     0.02
"unlist"                     0.26      0.02      0.00     0.00
"array"                      0.12      0.01      0.02     0.00
"sapply"                     0.12      0.01      0.00     0.00
"matrix"                     0.06      0.00      0.02     0.00
"Null"                       0.04      0.00      0.04     0.00
"print"                      0.04      0.00      0.00     0.00
"unique"                     0.04      0.00      0.00     0.00
"abs"                        0.02      0.00      0.02     0.00
"all"                        0.02      0.00      0.02     0.00
"aperm.default"              0.02      0.00      0.02     0.00
"as.matrix.mcmc"             0.02      0.00      0.02     0.00
"file.exists"                0.02      0.00      0.02     0.00
"list"                       0.02      0.00      0.02     0.00
"print.default"              0.02      0.00      0.02     0.00
"stopifnot"                  0.02      0.00      0.02     0.00
"unique.default"             0.02      0.00      0.02     0.00
"which.min"                  0.02      0.00      0.02     0.00
"<Anonymous>"                0.02      0.00      0.00     0.00
"aperm"                      0.02      0.00      0.00     0.00
"as.mcmc.list"               0.02      0.00      0.00     0.00
"as.mcmc.list.default"       0.02      0.00      0.00     0.00
"data"                       0.02      0.00      0.00     0.00
"mcmc.list"                  0.02      0.00      0.00     0.00
"print.gelman.diag"          0.02      0.00      0.00     0.00
"quantile.default"           0.02      0.00      0.00     0.00
"sort"                       0.02      0.00      0.00     0.00
"sort.default"               0.02      0.00      0.00     0.00
"sort.int"                   0.02      0.00      0.00     0.00
"summary"                    0.02      0.00      0.00     0.00
"summary.default"            0.02      0.00      0.00     0.00

如果我设置verbose = F,则cat会消失,但是%*%需要大约70%的时间。 我想知道是否值得尝试用C ++编写我的代码然后使用RCpp,或者如果因为花费这么多时间的函数是基本函数(已经用C编写)它就不值得我和我只需要忍受它或找到更好的算法。

编辑:根据Rprof,阻止我的一行是hit u = Z %*% r hitandrun u = Z %*% r

## This is the loop that is being run millions of times and taking forever
for(i in 1:(n*skiplength+discard)) {
        tmin<-0;tmax<-0;
        ## runs counts how many times tried to pick a direction, if
        ## too high fail.
        runs = 0
        while(tmin ==0 && tmax ==0) {
          ## r is a random unit vector in with basis in Z
          r <- rnorm(ncol(Z))
          r <- r/sqrt(sum(r^2))

          ## u is a unit vector in the appropriate k-plane pointing in a
          ## random direction Z %*% r is the same as in mirror
          u <- Z%*%r
          c <- y/u
          ## determine intersections of x + t*u with walls
          ## the limits on how far you can go backward and forward
          ## i.e. the maximum and minimum ratio y_i/u_i for negative and positive u.
          tmin <- max(-c[u>0]); tmax <- min(-c[u<0]);
          ## unboundedness
          if(tmin == -Inf || tmax == Inf){
            stop("problem is unbounded")
          }
          ## if stuck on boundary point
          if(tmin==0 && tmax ==0) {
            runs = runs + 1
            if(runs >= 1000) stop("hitandrun found can't find feasible direction, cannot generate points")
          }
        }

        ## chose a point on the line segment
        y <- y + (tmin + (tmax - tmin)*runif(1))*u;

        ## choose a point every 'skiplength' samples
        if(i %% skiplength == 0) {
          X[,index] <- y
          index <- index + 1
        }
        if(verbose) for(j in 1:nchar(str)) cat("\b")
        str <- paste(i)
        if(verbose) cat(str)
      }

它实际上是我在采样循环中进行矩阵乘法的唯一时间,但是我做了数千次,每次采样一百万次采样并抛出99%。

事实上,Rcpp已经被广泛用于此目的:MCMC。 你通常会获得相当不错的速度提升,大约30到50或70。

其中一个早期的软件包是Whit的rcppbugs ,在用他编写的一些类编程之后,他转换为Rcpp以便于使用。 对“Rcpp MCMC”的随意网络搜索将引导您访问几个帖子。

其他作者也使用过Rcpp。 它也位于(R)Stan的内部,因为您真的希望MCMC中固有循环结构尽可能快地运行。 因此编译。

我上周询问了rcpp-devel列表,我将在明天提交的一份简短的R用户组演示中讨论,并且“MCMC”建议或多或少占主导地位。 还介绍了另一个RUG的整个讲话。 我链接到线程,但不知何故,它落入了Gmane的rcpp-devel存档。

总而言之,我会说是的,你确实想考虑在这里使用Rcpp。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM