[英]Why is standard R median function so much slower than a simple C++ alternative?
我在C++
中实现了以下中位数,并通过Rcpp
在R
使用它:
// [[Rcpp::export]]
double median2(std::vector<double> x){
double median;
size_t size = x.size();
sort(x.begin(), x.end());
if (size % 2 == 0){
median = (x[size / 2 - 1] + x[size / 2]) / 2.0;
}
else {
median = x[size / 2];
}
return median;
}
如果我以后比较内建R中位数功能标准的表现,我得到通过以下结果microbenchmark
> x = rnorm(100)
> microbenchmark(median(x),median2(x))
Unit: microseconds
expr min lq mean median uq max neval
median(x) 25.469 26.990 34.96888 28.130 29.081 518.126 100
median2(x) 1.140 1.521 2.47486 1.901 2.281 47.897 100
为什么标准中值函数要慢得多? 这不是我所期望的......
正如@joran 所指出的,您的代码非常专业,一般来说,不太通用的函数、算法等通常性能更高。 看看median.default
:
median.default
# function (x, na.rm = FALSE)
# {
# if (is.factor(x) || is.data.frame(x))
# stop("need numeric data")
# if (length(names(x)))
# names(x) <- NULL
# if (na.rm)
# x <- x[!is.na(x)]
# else if (any(is.na(x)))
# return(x[FALSE][NA])
# n <- length(x)
# if (n == 0L)
# return(x[FALSE][NA])
# half <- (n + 1L)%/%2L
# if (n%%2L == 1L)
# sort(x, partial = half)[half]
# else mean(sort(x, partial = half + 0L:1L)[half + 0L:1L])
# }
有几种操作可以适应缺失值的可能性,这些操作肯定会影响函数的整体执行时间。 由于您的函数不会复制此行为,因此可以消除大量计算,但因此不会为具有缺失值的向量提供相同的结果:
median(c(1, 2, NA))
#[1] NA
median2(c(1, 2, NA))
#[1] 2
一对夫妇的其他因素,可能不会有尽可能多的为处理效果NA
S,但值得指出的:
median
,连同它使用的一些函数,都是 S3 泛型,所以在方法分派上花费了少量时间median
不仅仅适用于整数和数字向量; 它还将处理Date
、 POSIXt
,可能还有一堆其他类,并正确保留属性:median(Sys.Date() + 0:4)
#[1] "2016-01-15"
median(Sys.time() + (0:4) * 3600 * 24)
#[1] "2016-01-15 11:14:31 EST"
编辑:我应该提到下面的函数将导致原始向量被排序,因为NumericVector
是代理对象。 如果您想避免这种情况,您可以Rcpp::clone
输入向量并对Rcpp::clone
进行操作,或者使用您的原始签名(带有std::vector<double>
),这在从SEXP
的转换中隐式需要一个副本到std::vector
。
另请注意,您可以使用NumericVector
而不是std::vector<double>
来节省更多时间:
#include <Rcpp.h>
// [[Rcpp::export]]
double cpp_med(Rcpp::NumericVector x){
std::size_t size = x.size();
std::sort(x.begin(), x.end());
if (size % 2 == 0) return (x[size / 2 - 1] + x[size / 2]) / 2.0;
return x[size / 2];
}
microbenchmark::microbenchmark(
median(x),
median2(x),
cpp_med(x),
times = 200L
)
# Unit: microseconds
# expr min lq mean median uq max neval
# median(x) 74.787 81.6485 110.09870 92.5665 129.757 293.810 200
# median2(x) 6.474 7.9665 13.90126 11.0570 14.844 151.817 200
# cpp_med(x) 5.737 7.4285 11.25318 9.0270 13.405 52.184 200
Yakk 在上面的评论中提出了一个很好的观点 - 也由 Jerry Coffin 详细阐述 - 关于进行完整排序的低效率。 这是使用std::nth_element
的重写,以更大的向量为基准:
#include <Rcpp.h>
// [[Rcpp::export]]
double cpp_med2(Rcpp::NumericVector xx) {
Rcpp::NumericVector x = Rcpp::clone(xx);
std::size_t n = x.size() / 2;
std::nth_element(x.begin(), x.begin() + n, x.end());
if (x.size() % 2) return x[n];
return (x[n] + *std::max_element(x.begin(), x.begin() + n)) / 2.;
}
set.seed(123)
xx <- rnorm(10e5)
all.equal(cpp_med2(xx), median(xx))
all.equal(median2(xx), median(xx))
microbenchmark::microbenchmark(
cpp_med2(xx), median2(xx),
median(xx), times = 200L
)
# Unit: milliseconds
# expr min lq mean median uq max neval
# cpp_med2(xx) 10.89060 11.34894 13.15313 12.72861 13.56161 33.92103 200
# median2(xx) 84.29518 85.47184 88.57361 86.05363 87.70065 228.07301 200
# median(xx) 46.18976 48.36627 58.77436 49.31659 53.46830 250.66939 200
[这与其说是对您实际提出的问题的回答,不如说是一种扩展评论。]
甚至您的代码也可能有重大改进。 特别是,即使您只关心一两个元素,您也在对整个输入进行排序。
您可以使用std::nth_element
而不是std::sort
将其从 O(n log n) 更改为 O(n) 。 在偶数个元素的情况下,您通常希望使用std::nth_element
查找中间之前的元素,然后使用std::min_element
查找紧接其后的元素——但std::nth_element
也分区输入项,所以std::min_element
只需要在nth_element
之后中间上方的项上运行,而不是整个输入数组。 也就是说,在 nth_element 之后,你会得到这样的情况:
std::nth_element
的复杂度是“平均线性的”,并且(当然) std::min_element
也是线性的,所以整体复杂度是线性的。
因此,对于简单的情况(奇数个元素),您会得到如下结果:
auto pos = x.begin() + x.size()/2;
std::nth_element(x.begin(), pos, x.end());
return *pos;
...对于更复杂的情况(偶数个元素):
std::nth_element(x.begin(), pos, x.end());
auto pos2 = std::min_element(pos+1, x.end());
return (*pos + *pos2) / 2.0;
从这里可以期望max_element( ForwardIt first, ForwardIt last )
提供从第一个到最后一个的最大值,但是通过这样做: return (x[n] + *std::max_element(x.begin(), x.begin() + n)) / 2.
x.begin() + n
元素似乎被排除在计算之外。 为什么会有这种差异?
例如cpp_med2({6, 2, 1, 5, 3, 4})
产生x={2, 1, 3, 4, 5, 6}
其中:
n = 3
*x[n] = 4
*x.begin() = 2
*(x.begin() + n) = 4
*std::max_element(x.begin(), x.begin() + n) = 3
所以cpp_med2({6, 2, 1, 5, 3, 4})
返回 (4+3)/2=3.5 这是正确的中位数。 但是为什么*std::max_element(x.begin(), x.begin() + n)
等于 3 而不是 4? 该函数实际上似乎排除了最大值计算中的最后一个元素 (4)。
已解决(我认为):在:
查找范围 [first, last) 中最大的元素
)
关闭手段最后被排除在计算之外。 那是对的吗?
此致
我不确定您指的是什么“标准”实现。
无论如何:如果有一个,作为标准库的一部分,肯定不允许更改向量中元素的顺序(如您的实现那样),因此它肯定必须在副本上工作。
创建此副本需要时间和 CPU(以及大量内存),这会影响运行时间。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.