How to reduce the time complexity of my O(N^2) C-index function?

I have the following MATLAB function, which calculates the concordance index (C-index) for a given set of predictions and observed values:

function civalue = CI(predval)
% FUNCTION civalue = CI(predval)
%
% DESCRIPTION: 
% - This function calculates the concordance index (C-index). It runs in
% O(n^2) time, so it is not suitable for large inputs.
%
% INPUTS:
% 'predval' an n-by-2 matrix whose first column holds the predicted
% values and whose second column holds the observed (actual) values.
%
% OUTPUT: 
% 'civalue' the CI-value.

N = 0;
hSum = 0;

for i = 1:size(predval, 1)

    yi_pred = predval(i, 1);
    yi_val = predval(i, 2);
    for j = i+1:size(predval, 1)
        yj_pred = predval(j, 1);
        yj_val = predval(j, 2);
        if yi_val ~= yj_val
            N = N + 1;
            if  (yi_pred < yj_pred && yi_val < yj_val) || (yi_pred > yj_pred && yi_val > yj_val) % Order correct
                hSum = hSum + 1;
            elseif (yi_pred < yj_pred && yi_val > yj_val) || (yi_pred > yj_pred && yi_val < yj_val) % Order opposite 
                hSum = hSum + 0;
        elseif yi_pred == yj_pred % Tied predictions (random order): half credit
                hSum = hSum + 0.5;
            end
        end
    end

end

civalue = hSum / N;

My function has a time complexity of O(N^2), because it compares every pair of data points. Any ideas how I could reduce the time complexity of my code?

The idea behind the CI value, or C-index, is to measure how well a prediction model ranks data points in the correct order. The input to this function is a set of observed values X and their corresponding predictions Y. The function compares only pairs of data points whose observed values differ, because only those pairs have a definite true ranking.
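Written as a formula, with x_i for the observed values (second column of predval) and y_i for the predictions (first column), the value returned by CI (hSum / N) is:

```latex
C = \frac{\displaystyle\sum_{i<j,\; x_i \neq x_j} \Big( \mathbf{1}\big[(y_i - y_j)(x_i - x_j) > 0\big] + \tfrac{1}{2}\,\mathbf{1}\big[y_i = y_j\big] \Big)}{\big|\{(i,j) : i<j,\; x_i \neq x_j\}\big|}
```

Here \mathbf{1}[\cdot] is 1 when the condition holds and 0 otherwise. Both the numerator and the denominator range over up to n(n-1)/2 pairs, which is where the O(n^2) cost of a direct double loop comes from.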

For example, let's say you have two observed values of some variable, e.g. a stock price: P1 = $5 and P2 = $7.

Now we create a model that tries to predict the stock prices. Say we built the model and tested its ability to predict the price, and for the two data points P1 and P2 it predicted Y1 = $5.50 and Y2 = $8.

You can see that the model got the ORDER correct (P1 < P2 and Y1 < Y2) but not the absolute values. This is useful when we need to choose among a set of alternatives, e.g. which stock to buy because it will increase in value the most.
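The example can be checked numerically. Below is a minimal vectorized sketch (a plain script; the variable names are illustrative) that scores each unordered pair exactly once by using the upper triangle of the pairwise comparison matrices. It uses the two stocks from the question plus a hypothetical third stock (P3 = $6, predicted Y3 = $9), so that one pair ends up in the wrong order:

```matlab
% Stocks 1 and 2 are from the question; stock 3 (P3 = $6, Y3 = $9) is hypothetical.
% Column 1 = predicted price Y, column 2 = observed price P (same layout as predval).
predval = [5.5 5; 8 7; 9 6];

p = predval(:, 1);              % predictions
y = predval(:, 2);              % observed values
n = numel(p);

% Keep each unordered pair (i, j) with i < j exactly once
isUpper = triu(true(n), 1);
comparable = isUpper & bsxfun(@ne, y, y.');     % observed values differ

% Concordant: the prediction difference has the same sign as the observed difference
concordant = comparable & (sign(bsxfun(@minus, p, p.')) == sign(bsxfun(@minus, y, y.')));
tied = comparable & bsxfun(@eq, p, p.');        % equal predictions count as 0.5

civalue = (sum(concordant(:)) + 0.5 * sum(tied(:))) / sum(comparable(:));
fprintf('C-index: %.4f\n', civalue);
```

Here the pairs (1,2) and (1,3) are ordered correctly and (2,3) is not, so the C-index is 2/3. With only the first two rows, the single comparable pair is concordant and the C-index is 1.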

Thank you for all your help! Please let me know if you need any more information. :)

Here is the comparison between my own implementation and Martin's:

[Image: comparison between my implementation and Martin's]

You can significantly improve the runtime by vectorizing the inner loop. The code below could be optimized further, at the expense of legibility. On my machine, with random input, it runs about 50x faster and produces the same results. (Random input is probably a poor test case, though: exact ties between random real numbers essentially never occur, so the == branches never execute.)

N = 0;
hSum = 0;
for i = 1:size(predval, 1)

    yi_pred = predval(i, 1);
    yi_val = predval(i, 2);
    % Compare point i against all later points j > i at once
    yj_pred = predval(i+1:end, 1);
    yj_val = predval(i+1:end, 2);
    idxs = yi_val ~= yj_val;   % only pairs with different observed values count
    N = N + sum(idxs);

    yj_pred = yj_pred(idxs); % redefined to make the next lines prettier
    yj_val = yj_val(idxs);
    hSum = hSum + sum((yi_pred < yj_pred & yi_val < yj_val) | ...
        (yi_pred > yj_pred & yi_val > yj_val)); % Order correct
    hSum = hSum + 0.5*sum(yi_pred == yj_pred); % Order random (tied predictions)
end

civalue = hSum / N;
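To follow up on the caveat about random input: drawing integer-valued data makes ties common in both columns, so the == branches actually run. A minimal sketch, assuming a hypothetical 300-by-2 input from randi(5, ...), that runs the original double loop and the vectorized loop above side by side as a plain script (the _loop/_vec variable names are illustrative):

```matlab
% Hypothetical test input: small integers, so ties occur in both columns
predval = randi(5, 300, 2);
n = size(predval, 1);

% Original double loop (discordant pairs add 0, so that branch is omitted)
N_loop = 0;
hSum_loop = 0;
for i = 1:n
    for j = i+1:n
        yi_pred = predval(i, 1);  yi_val = predval(i, 2);
        yj_pred = predval(j, 1);  yj_val = predval(j, 2);
        if yi_val ~= yj_val
            N_loop = N_loop + 1;
            if (yi_pred < yj_pred && yi_val < yj_val) || (yi_pred > yj_pred && yi_val > yj_val)
                hSum_loop = hSum_loop + 1;      % order correct
            elseif yi_pred == yj_pred
                hSum_loop = hSum_loop + 0.5;    % tied predictions
            end
        end
    end
end
ci_loop = hSum_loop / N_loop;

% Vectorized inner loop from the answer above
N_vec = 0;
hSum_vec = 0;
for i = 1:n
    yi_pred = predval(i, 1);
    yi_val = predval(i, 2);
    yj_pred = predval(i+1:end, 1);
    yj_val = predval(i+1:end, 2);
    idxs = yi_val ~= yj_val;
    N_vec = N_vec + sum(idxs);
    yj_pred = yj_pred(idxs);
    yj_val = yj_val(idxs);
    hSum_vec = hSum_vec + sum((yi_pred < yj_pred & yi_val < yj_val) | ...
        (yi_pred > yj_pred & yi_val > yj_val));
    hSum_vec = hSum_vec + 0.5*sum(yi_pred == yj_pred);
end
ci_vec = hSum_vec / N_vec;

fprintf('loop: %.6f   vectorized: %.6f\n', ci_loop, ci_vec);
```

Because the data are random, the printed value changes from run to run, but the two results should always agree, now including the tie handling.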

The complexity of this vectorized version is still O(n^2), though.
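Getting below O(n^2) needs a different approach. One standard option, not taken from the answers here and closely related to the O(n log n) inversion counting used for Kendall's tau, is to sort the points by observed value and keep a Fenwick (binary indexed) tree over the ranks of the predictions. For each point, a prefix-count query then tells how many points with a strictly smaller observed value also have a strictly smaller prediction. The sketch below is a plain script on a small hypothetical input, with illustrative variable names. It follows the tie rules of the original CI: pairs with equal observed values are skipped, and pairs with equal predictions count as 0.5. The cost is one sort, one unique, and one O(log n) query and update per point, so O(n log n) overall.

```matlab
% Hypothetical example input: column 1 = predictions, column 2 = observed values
predval = [5.5 5; 8 7; 8 6; 8 7; 4 5];

pred = predval(:, 1);
obs  = predval(:, 2);
n = numel(pred);

% Dense ranks 1..m of the predictions; equal predictions share a rank
[~, ~, prank] = unique(pred);
m = max(prank);
tree = zeros(m, 1);   % Fenwick tree: prefix counts of inserted ranks
cnt  = zeros(m, 1);   % exact number of inserted points per rank

% Visit points in ascending order of their observed value
[obsSorted, order] = sort(obs);

N = 0;          % comparable pairs (different observed values)
hSum = 0;       % concordant pairs + 0.5 * pairs with tied predictions
inserted = 0;   % points already in the tree (all have smaller observed values)
g = 1;
while g <= n
    % Block g..h holds all points sharing the observed value obsSorted(g)
    h = g;
    while h < n && obsSorted(h + 1) == obsSorted(g)
        h = h + 1;
    end

    % Compare each block member with every inserted point; those all have a
    % strictly smaller observed value, so every such pair is comparable
    for k = g:h
        r = prank(order(k));
        less = 0;                                   % inserted points with smaller prediction
        idx = r - 1;
        while idx > 0
            less = less + tree(idx);
            idx = bitand(idx, idx - 1);             % drop the lowest set bit
        end
        hSum = hSum + less + 0.5 * cnt(r);          % concordant + half of ties
        N = N + inserted;
    end

    % Insert the block only afterwards, so equal observed values are never paired
    for k = g:h
        r = prank(order(k));
        cnt(r) = cnt(r) + 1;
        idx = r;
        while idx <= m
            tree(idx) = tree(idx) + 1;
            idx = idx + (idx - bitand(idx, idx - 1)); % add the lowest set bit
        end
        inserted = inserted + 1;
    end
    g = h + 1;
end

civalue = hSum / N;
disp(civalue)
```

For this input the result is 0.875: 6 concordant pairs plus 2 pairs with tied predictions (0.5 each), over 8 comparable pairs, which is the same value the original double loop gives. The helper-free script style is only so the block runs as-is; in practice the body would go inside a function like CI.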

Assuming your final goal is to improve runtime performance, and you have enough memory for a fully vectorized approach, this could be one option:

% Column arrays
c1 = predval(:,1);   % predictions
c2 = predval(:,2);   % observed values

% Logical n-by-n arrays for the IF conditionals in the original code
start_cond = bsxfun(@ne,c2,c2.');   % starting condition: observed values differ

% The remaining three IF conditionals
case1 = (bsxfun(@lt,c1,c1.') & bsxfun(@lt,c2,c2.')) | ...
    (bsxfun(@gt,c1,c1.') & bsxfun(@gt,c2,c2.'));   % order correct
case2 = (bsxfun(@lt,c1,c1.') & bsxfun(@gt,c2,c2.')) | ...
    (bsxfun(@gt,c1,c1.') & bsxfun(@lt,c2,c2.'));   % order opposite
case3 = bsxfun(@eq,c1,c1.');                        % tied predictions

% Count the cases. The arrays are symmetric, so every pair appears twice:
% divide by 2, and by 4 for ties, which also count only 0.5 each
w1 = start_cond & case1;
w2 = start_cond & ~case1 & ~case2 & case3;
hSum = sum(w1(:))/2 + sum(w2(:))/4;
N = sum(start_cond(:))/2;
civalue = hSum / N;
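As a rough sizing of the memory caveat: MATLAB stores each logical element in one byte, and the vectorized snippet keeps six n-by-n logical arrays (start_cond, case1, case2, case3, w1 and w2) alive at once, not counting the temporaries created by bsxfun. For a hypothetical n = 10,000:

```latex
6 \times n^2 \ \text{bytes} = 6 \times (10^4)^2 \ \text{bytes} = 6 \times 10^8 \ \text{bytes} \approx 600\ \text{MB}
```

The loop-based versions, by contrast, need at most O(n) extra memory.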
