Efficiently randomly drawing from a multivariate normal distribution

Has anyone run into the problem of needing to draw randomly from a very high-dimensional multivariate normal distribution (say, dimension = 10,000)? The rmvnorm function of the mvtnorm package is impractical at that size.

I know this article has an Rcpp implementation of the dmvnorm function from the mvtnorm package, so I was wondering whether something equivalent exists for rmvnorm?

Here's a quick comparison of mvtnorm::rmvnorm and an Rcpp implementation given here by Ahmadou Dicko. The times presented are for 100 draws from a multivariate normal distribution, with dimension ranging from 500 to 2500 for mvtnorm and from 500 to 5000 for the Rcpp version. From the graph below you can probably infer the time required for a dimension of 10000. Times include the overhead of generating the random mu vector and the diag matrix, but these are the same for both approaches and are trivial for the dimensions in question (e.g. 0.2 sec for diag(10000)).

library(Rcpp)
library(RcppArmadillo)
library(inline)
library(mvtnorm)

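code <- '
using namespace Rcpp;
// Convert the R arguments to C++ / Armadillo types.
int n = as<int>(n_);
arma::vec mu = as<arma::vec>(mu_);
arma::mat sigma = as<arma::mat>(sigma_);
int ncols = sigma.n_cols;
// n x ncols matrix of independent standard normal draws.
arma::mat Y = arma::randn(n, ncols);
// Scale by the upper Cholesky factor of sigma, then add mu to every row.
return wrap(arma::repmat(mu, 1, n).t() + Y * arma::chol(sigma));
'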

rmvnorm.rcpp <- 
  cxxfunction(signature(n_="integer", mu_="numeric", sigma_="matrix"), code,
              plugin="RcppArmadillo", verbose=TRUE)

rcpp.time <- sapply(seq(500, 5000, 500), function(x) {
  system.time(rmvnorm.rcpp(100, rnorm(x), diag(x)))[3]  
})

mvtnorm.time <- sapply(seq(500, 2500, 500), function(x) {
  system.time(rmvnorm(100, rnorm(x), diag(x)))[3]  
})


plot(seq(500, 5000, 500), rcpp.time, type='o', xlim=c(0, 5000),
     ylim=c(0, max(mvtnorm.time)), xlab='dimension', ylab='time (s)')

points(seq(500, 2500, 500), mvtnorm.time, type='o', col=2)

legend('topleft', legend=c('rcpp', 'mvtnorm'), lty=1, col=1:2, bty='n')
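
For reference, the Rcpp snippet above relies on the Cholesky factorisation: if R is the upper triangular factor with t(R) %*% R equal to sigma, and Z holds independent standard normal draws in its rows, then the rows of Z %*% R have covariance sigma. Here is a minimal pure-R sketch of the same idea. The name rmvnorm_chol is a hypothetical helper of mine, not something from mvtnorm or from the linked Rcpp code, and it assumes sigma is symmetric positive definite.

```r
# Hypothetical helper (not part of mvtnorm): Cholesky-based sampler.
# Assumes sigma is symmetric positive definite; chol() errors otherwise.
rmvnorm_chol <- function(n, mu, sigma) {
  p <- length(mu)
  R <- chol(sigma)                     # upper triangular, t(R) %*% R == sigma
  Z <- matrix(rnorm(n * p), nrow = n)  # n x p independent N(0, 1) draws
  sweep(Z %*% R, 2, mu, "+")           # add mu to every row
}

# Small sanity check on a 2-dimensional example.
set.seed(1)
sigma <- matrix(c(2, 0.5, 0.5, 1), 2, 2)
X <- rmvnorm_chol(1e5, c(1, -1), sigma)
colMeans(X)  # should be close to c(1, -1)
cov(X)       # should be close to sigma
```

This is only meant to show what is being computed, not to compete with the compiled version on speed. As far as I know, mvtnorm::rmvnorm also accepts method = "chol" (its default is "eigen"), which might be worth timing alongside the two approaches above.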

[Plot: elapsed time (s) against dimension for rcpp and mvtnorm]
