
How to speed up my code [includes example] in MATLAB?

I want to speed up my code. Currently, I use if statements to do it. The code could be made faster with the convolution approach, but that only works for the simple case (such as a pairwise neighborhood). Let me define my problem.

I have a matrix I=[1 1 1;2 2 2;2 2 1] which has two labels {1,2}. I padded it; the padded version is shown on its right side. For each pixel in I, we can define pairwise or triple neighborhoods. Each neighborhood is checked with the rule: "if these neighborhood values have the same class as the pixel, set the cost value to -beta; otherwise set the cost to beta".

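[Figure: the matrix I, its padded version, and the yellow pixel with its possible neighborhood cases on the rightmost side]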

For example, consider the yellow pixel in the figure above. Its label is 2. We need to compute the total cost over all of the possible neighborhood cases shown on the rightmost side. The pixel of interest takes each value from the label set {1,2} in turn. In the figure above, I only show the first case, where the yellow pixel is set to 1; the figure for the second case is the same, but with the yellow pixel set to 2. My task is to compute the cost function based on the above rule.
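The rule for a single clique fits in a one-line helper. A minimal sketch, assuming the labels of the pixels in a clique are passed as a vector; `clique_cost` is a hypothetical name, not part of the code below:

```matlab
% Hypothetical helper: cost of one clique for the candidate label `label`.
% vals holds the labels of the pixels in the clique.
% all(...) is true only if every pixel in the clique has the candidate label.
clique_cost = @(vals, label, beta) beta - 2 * beta * all(vals == label);

beta = 1;
c_same = clique_cost([2 2], 2, beta);   % both neighbours match -> -beta
c_diff = clique_cost([2 1], 2, beta);   % one mismatch          -> +beta
fprintf('same: %g, different: %g\n', c_same, c_diff);
```

The total cost for a pixel and a candidate label is then the sum of this value over all of its cliques.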

This is my code. It uses if statements, which makes it very slow. Could you help me speed it up? I tried the convolution approach, but I have no idea how to define a mask for the triple neighborhood. Thanks, all.

function U=compute_gibbs(Imlabel,beta,num_class)
% Example call: U = compute_gibbs([1 1 1;2 2 2;2 2 1], 1, 2)
U=zeros([size(Imlabel) num_class]);
Imlabel = padarray(Imlabel,[1 1],'replicate','both');
[row,col] = size(Imlabel);
for ii = 2:row-1        
    for jj = 2:col-1
        for l = 1:num_class
            U(ii-1,jj-1,l)=GibbsEnergy(Imlabel,ii,jj,l,beta);
        end
    end
end
function energy = GibbsEnergy(img,i,j,label,beta)
    % img is the labeled image
    energy = 0;
    if (label == img(i,j)) energy = energy-beta;
        else energy = energy+beta;end        
    % North, south, east and west
    if (label == img(i-1,j)) energy = energy-beta;
        else energy = energy+beta;end
    if (label == img(i,j+1)) energy = energy-beta;
        else energy = energy+beta;end
    if (label == img(i+1,j)) energy = energy-beta;
        else energy = energy+beta;end
    if (label == img(i,j-1)) energy = energy-beta;
        else energy = energy+beta;end
    % diagonal elements
    if (label == img(i-1,j-1)) energy = energy-beta;
        else energy = energy+beta;end
    if (label == img(i-1,j+1)) energy = energy-beta;
        else energy = energy+beta;end
    if (label == img(i+1,j+1)) energy = energy-beta;
        else energy = energy+beta;end
    if (label == img(i+1,j-1)) energy = energy-beta;
        else energy = energy+beta;end
     %% Triangle elements
    % Case a 
    if(label==img(i-1,j)&label==img(i-1,j-1)) energy = energy-beta;
        else energy = energy+beta;end 
    if(label==img(i,j-1)&label==img(i+1 ,j)) energy = energy-beta;
        else energy = energy+beta;end
    if(label==img(i,j+1)&&label==img(i+1 ,j+1)) energy = energy-beta;
        else energy = energy+beta;end 
    % Case b 
    if(label==img(i-1,j-1)&label==img(i,j-1)) energy = energy-beta;
        else energy = energy+beta;end     
     if(label==img(i-1,j)&label==img(i ,j+1)) energy = energy-beta;
        else energy = energy+beta;end  
     if(label==img(i+1,j)&label==img(i+1,j+1)) energy = energy-beta;
         else energy = energy+beta;end  
    % Case c   
    if(label==img(i,j-1)&label==img(i+1,j-1)) energy = energy-beta;
         else energy = energy+beta;end  
    if(label==img(i+1,j)&label==img(i,j+1)) energy = energy-beta;
         else energy = energy+beta;end  
    if(label==img(i-1 ,j)&label==img(i-1,j+1)) energy = energy-beta;
        else energy = energy+beta;end 
    % Case d 
    if(label==img(i,j-1)&label==img(i-1,j)) energy = energy-beta;
        else energy = energy+beta;end 
    if(label==img(i-1 ,j+1)&label==img(i,j+1)) energy = energy-beta;
        else energy = energy+beta;end 
     if(label==img(i+1,j-1)&label==img(i+1 ,j)) energy = energy-beta;
        else energy = energy+beta;end 

    %% Rectangular
    if(label==img(i-1,j-1)&label==img(i,j-1)&label==img(i-1 ,j)) energy = energy-beta;
        else energy = energy+beta;end 
    if(label==img(i,j-1)&label==img(i+1,j-1)&label==img(i+1 ,j)) energy = energy-beta;
        else energy = energy+beta;end 
     if(label==img(i+1,j)&label==img(i +1 ,j+1)&label==img(i,j+1)) energy = energy-beta;
        else energy = energy+beta;end 
     if(label==img(i-1,j)&label==img(i-1,j+1)&label==img(i ,j+1)) energy = energy-beta;
        else energy = energy+beta;end 

This is one faster way, but it only works for the simple case (the pairwise neighborhood, first row), while my case also includes single, triple, ... neighborhoods.

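% Pairwise (8-neighbour) term only; Imlabel is the unpadded label image.
% C(:,:,l) is 1 where the pixel has label l and 0 elsewhere.
C = double(bsxfun(@eq, Imlabel, permute(1:num_class, [1 3 2])));
%% Replace if statement
% Replicate-pad each slice, as in the loop version
Cpad = padarray(C, [1 1], 'replicate', 'both');

% Count how many of the 8 neighbours carry label l, then turn the count into
% energy: -beta per matching neighbour, +beta per non-matching one
mask2 = ones(3,3); mask2(2,2) = 0;
same = convn(Cpad, mask2, 'valid');
energy = -beta * same + beta * (8 - same);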
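For the pairwise term on its own, the convolution result can be checked against a direct loop over the 8 neighbors. A sketch on the example matrix with beta = 1, assuming replicate padding as in the loop version; the padding is done by indexing so that no toolbox is needed (E_conv and E_loop are illustrative names):

```matlab
Imlabel = [1 1 1; 2 2 2; 2 2 1];
beta = 1;
num_class = 2;
[m, n] = size(Imlabel);

% Replicate-pad by repeating the first/last rows and columns
P = Imlabel([1 1:end end], [1 1:end end]);

% Convolution: count neighbours with label l, then convert the count to energy
mask2 = ones(3, 3); mask2(2, 2) = 0;
E_conv = zeros(m, n, num_class);
for l = 1:num_class
    same = conv2(double(P == l), mask2, 'valid');
    E_conv(:, :, l) = -beta * same + beta * (8 - same);
end

% Reference: explicit loop over the 8 neighbours
E_loop = zeros(m, n, num_class);
for i = 1:m
    for j = 1:n
        win = P(i:i+2, j:j+2);      % 3x3 window centred on Imlabel(i,j)
        nb = win([1:4 6:9]);        % drop the centre (linear index 5)
        for l = 1:num_class
            % -beta for each matching neighbour, +beta otherwise
            E_loop(i, j, l) = sum(beta * (1 - 2 * (nb == l)));
        end
    end
end

fprintf('max difference: %g\n', max(abs(E_conv(:) - E_loop(:))));
```

Because mask2 is symmetric, it does not matter here that conv2 flips the kernel.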

Here are my attempts. I can't really tell whether either one will be faster for you, because I'm using Octave rather than MATLAB and the timings can be wildly different. For instance, for loops are still very slow in Octave. You'll have to test them and see how they compare.

Matrix Multiplication

As @AnderBiguri notes in the comments, one way to go is matrix multiplication. Say you take a 3x3 neighborhood

nbr = [0 0 0;
       1 0 0;
       1 1 0];

and you want to know whether the top-left element is 1. You can multiply it element-wise by the mask

mask = [1 0 0;
        0 0 0;
        0 0 0];

result = sum(mask(:) .* nbr(:));

(I'm taking a shortcut here by assuming that the neighborhood is a binary matrix. In the actual code, I simply use nbr == current_class to make it one.)

If result equals the number of 1 elements in the mask, you've got a match. In this case, the element-wise product of the two is all zeros, so there is no match.
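A quick check of this matching rule on the neighborhood above, with two extra masks added for illustration (`mask_bl` and `mask_row` are hypothetical): one that matches, and a two-element mask that is only partly covered, which is why the sum has to equal the number of ones rather than just be nonzero:

```matlab
nbr      = [0 0 0; 1 0 0; 1 1 0];
mask     = [1 0 0; 0 0 0; 0 0 0];   % top-left element (from the text)
mask_bl  = [0 0 0; 0 0 0; 1 0 0];   % bottom-left element
mask_row = [0 0 0; 1 1 0; 0 0 0];   % left and centre of the middle row

% Sum over all elements; mask(:) flattens the 3x3 matrix into a column
result     = sum(mask(:)     .* nbr(:));   % 0
result_bl  = sum(mask_bl(:)  .* nbr(:));   % 1
result_row = sum(mask_row(:) .* nbr(:));   % 1, but the mask has two ones

% A match means every 1 in the mask lines up with a 1 in nbr
is_match     = result     == nnz(mask);       % false
is_match_bl  = result_bl  == nnz(mask_bl);    % true
is_match_row = result_row == nnz(mask_row);   % false
```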

Rather than multiplying element-wise and then summing the elements of the result, we can turn nbr and mask into vectors and use vector multiplication:

m = mask(:).';
n = nbr(:);
result = m * n;

This gives the same value as the previous result. If you have a matrix of masks, you can multiply it by the neighborhood vector and get all of the results at once. So the first step is to generate the 25 mask vectors. Each row is one 3x3 mask flattened in column-major order (the same order as nbr(:)): row 1 checks the pixel itself, rows 2-9 the 8 neighbors, rows 10-21 the 12 triples and rows 22-25 the 4 squares from the question:

masks = [
   0   0   0   0   1   0   0   0   0;
   0   0   0   0   0   1   0   0   0;
   0   0   0   1   0   0   0   0   0;
   0   0   0   0   0   0   0   1   0;
   0   1   0   0   0   0   0   0   0;
   1   0   0   0   0   0   0   0   0;
   0   0   0   0   0   0   0   0   1;
   0   0   0   0   0   0   1   0   0;
   0   0   1   0   0   0   0   0   0;
   1   1   0   0   0   0   0   0   0;
   1   0   0   1   0   0   0   0   0;
   0   0   0   1   0   0   1   0   0;
   0   0   0   0   0   0   1   1   0;
   0   0   0   0   0   0   0   1   1;
   0   0   0   0   0   1   0   0   1;
   0   0   1   0   0   1   0   0   0;
   0   1   1   0   0   0   0   0   0;
   0   0   0   1   0   0   0   1   0;
   0   0   0   0   0   1   0   1   0;
   0   1   0   1   0   0   0   0   0;
   0   1   0   0   0   1   0   0   0;
   1   1   0   1   0   0   0   0   0;
   0   0   0   1   0   0   1   1   0;
   0   0   0   0   0   1   0   1   1;
   0   1   1   0   0   1   0   0   0];
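Typing the table by hand is error-prone. A sketch that builds the same 25 rows from lists of window positions instead (column-major linear indices into the 3x3 window, so 5 is the centre); `cliques` is a hypothetical name:

```matlab
% Window positions (column-major):  1 4 7
%                                    2 5 8
%                                    3 6 9
% Row 1: the pixel itself; rows 2-9: the 8 neighbours;
% rows 10-21: the 12 triples (two neighbours; the centre is the candidate label);
% rows 22-25: the 4 2x2 squares (three neighbours plus the centre)
cliques = {5, 6, 4, 8, 2, 1, 9, 7, 3, ...
           [1 2], [1 4], [4 7], [7 8], [8 9], [6 9], ...
           [3 6], [2 3], [4 8], [6 8], [2 4], [2 6], ...
           [1 2 4], [4 7 8], [6 8 9], [2 3 6]};

masks = zeros(numel(cliques), 9);
for r = 1:numel(cliques)
    masks(r, cliques{r}) = 1;      % set the positions this clique checks
end
```

Each row of masks then matches the corresponding row of the hand-typed table, and sum(masks, 2) gives the number of window positions each clique checks.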

Now when you multiply masks by the neighborhood, you get all of the results at once. Then you compare the result with the row sums of masks to see which ones match.

result = masks * n;
matches = sum(masks, 2) == result;
match_count = sum(matches);

For each match, we subtract beta from the energy. For each non-match, we add beta, so

possible_matches = 25; %// the number of neighborhood types
energy = -beta * match_count + beta * (possible_matches - match_count);
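As a worked example, the centre pixel of the question's matrix I has I itself as its 3x3 window. The sketch below (beta = 1, masks rebuilt from a hypothetical clique list rather than typed out) counts the matches for both candidate labels and applies the energy formula:

```matlab
% Hypothetical compact form of the 25-row masks table (column-major positions)
cliques = {5, 6, 4, 8, 2, 1, 9, 7, 3, ...
           [1 2], [1 4], [4 7], [7 8], [8 9], [6 9], ...
           [3 6], [2 3], [4 8], [6 8], [2 4], [2 6], ...
           [1 2 4], [4 7 8], [6 8 9], [2 3 6]};
masks = zeros(numel(cliques), 9);
for r = 1:numel(cliques)
    masks(r, cliques{r}) = 1;
end

beta = 1;
nbr = [1 1 1; 2 2 2; 2 2 1];        % 3x3 window around the centre pixel of I
possible_matches = size(masks, 1);
energy = zeros(1, 2);
for label = 1:2
    n = double(nbr(:) == label);                 % binary column, column-major
    match_count = sum(masks * n == sum(masks, 2));
    energy(label) = -beta * match_count + beta * (possible_matches - match_count);
end
disp(energy)    % energy is [13 5]
```

For label 1, 6 of the 25 cliques match, so the energy is -6 + 19 = 13; for label 2, 10 match, giving -10 + 15 = 5.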

Now all we have to do is figure out how to get all of the 3x3 neighborhoods out of our image. Fortunately, MATLAB has the im2col function, which does just this. Even better, with the 'sliding' option it only takes the valid (complete) neighborhoods of the image, so if the image is already padded, you're ready to go.
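padarray and im2col are both Image Processing Toolbox functions. Without that toolbox, the same padded image and window matrix can be produced with plain indexing. A sketch, assuming replicate padding and a 3x3 window (P and cols are illustrative names):

```matlab
A = [1 1 1; 2 2 2; 2 2 1];
[m, n] = size(A);
P = A([1 1:end end], [1 1:end end]);   % replicate-pad by repeating edge rows/cols

cols = zeros(9, m * n);
r = 0;
for dc = 0:2            % window column offset (outer loop -> column-major order)
    for dr = 0:2        % window row offset
        r = r + 1;
        block = P(1 + dr:m + dr, 1 + dc:n + dc);
        cols(r, :) = block(:).';
    end
end
% Column k of cols is the 3x3 window around pixel k of A (column-major),
% itself flattened column-major, which is the layout of im2col(P,[3 3],'sliding').
```

Each pass of the double loop copies one window position for every pixel at once, so only 9 array slices are taken regardless of the image size.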

function G = gibbs(img, beta, classcount)

   masks = [
      0   0   0   0   1   0   0   0   0;
      0   0   0   0   0   1   0   0   0;
      0   0   0   1   0   0   0   0   0;
      0   0   0   0   0   0   0   1   0;
      0   1   0   0   0   0   0   0   0;
      1   0   0   0   0   0   0   0   0;
      0   0   0   0   0   0   0   0   1;
      0   0   0   0   0   0   1   0   0;
      0   0   1   0   0   0   0   0   0;
      1   1   0   0   0   0   0   0   0;
      1   0   0   1   0   0   0   0   0;
      0   0   0   1   0   0   1   0   0;
      0   0   0   0   0   0   1   1   0;
      0   0   0   0   0   0   0   1   1;
      0   0   0   0   0   1   0   0   1;
      0   0   1   0   0   1   0   0   0;
      0   1   1   0   0   0   0   0   0;
      0   0   0   1   0   0   0   1   0;
      0   0   0   0   0   1   0   1   0;
      0   1   0   1   0   0   0   0   0;
      0   1   0   0   0   1   0   0   0;
      1   1   0   1   0   0   0   0   0;
      0   0   0   1   0   0   1   1   0;
      0   0   0   0   0   1   0   1   1;
      0   1   1   0   0   1   0   0   0];

   [m,n] = size(img);
   G = zeros(m, n, classcount);   % preallocate the output
   possible_matches = size(masks, 1);
   Imlabel = padarray(img, [1 1], 'replicate', 'both');

   col_label = im2col(Imlabel, [3 3], 'sliding');
   target = repmat(sum(masks, 2), [1, m*n]);

   for ii = 1:classcount
      found = masks*(col_label == ii);
      match_count = sum(found == target, 1);
      energy = -beta * match_count + beta*(possible_matches - match_count);
      G(:,:,ii) = reshape(energy, m, n);
   end

end
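The same function can be written without the Image Processing Toolbox by gathering the windows with plain indexing. A sketch, assuming the same masks (built here from a hypothetical clique list) and replicate padding, run on the question's matrix with beta = 1:

```matlab
cliques = {5, 6, 4, 8, 2, 1, 9, 7, 3, ...
           [1 2], [1 4], [4 7], [7 8], [8 9], [6 9], ...
           [3 6], [2 3], [4 8], [6 8], [2 4], [2 6], ...
           [1 2 4], [4 7 8], [6 8 9], [2 3 6]};
masks = zeros(numel(cliques), 9);
for r = 1:numel(cliques)
    masks(r, cliques{r}) = 1;
end
possible_matches = size(masks, 1);
target = sum(masks, 2);

img = [1 1 1; 2 2 2; 2 2 1];
beta = 1;
classcount = 2;
[m, n] = size(img);
P = img([1 1:end end], [1 1:end end]);    % replicate padding

% Gather every 3x3 window as a column (same layout as im2col 'sliding')
cols = zeros(9, m * n);
r = 0;
for dc = 0:2
    for dr = 0:2
        r = r + 1;
        blk = P(1 + dr:m + dr, 1 + dc:n + dc);
        cols(r, :) = blk(:).';
    end
end

G = zeros(m, n, classcount);
for ii = 1:classcount
    % A clique matches when all of its window positions carry label ii
    match_count = sum(bsxfun(@eq, masks * double(cols == ii), target), 1);
    energy = -beta * match_count + beta * (possible_matches - match_count);
    G(:, :, ii) = reshape(energy, m, n);
end
```

For the centre pixel this gives 13 for label 1 and 5 for label 2; for the top-left pixel it gives -3 and 15.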

Look-Up Table

If you look at the Matrix Multiplication solution, it multiplies the neighborhood of each pixel by the 25 masks. For a 1000 x 1000 image, that's 1000 x 1000 x 25 x 9 = 225M multiplications. But there are only 512 (2^9) possible binary neighborhood configurations. So if we figure out what those 512 configurations are, multiply them by the masks and count the matches, we get a 512-element look-up table, and all we have to do for each neighborhood in the image is calculate its index. Here's how to create the look-up table using masks from above:

possible_neighborhoods = de2bi(0:511, 9).';
found = masks * possible_neighborhoods;
target = repmat(sum(masks, 2), [1, size(found, 2)]);
LUT = sum(found == target, 1);

This is pretty much what we were doing in each loop before, but now we're doing it for all possible neighborhoods, which are exactly the bit patterns of the numbers 0:511.
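de2bi belongs to the Communications Toolbox. If it is not available, the same 9x512 bit matrix can be built with floor and mod. A minimal sketch (the clique list is a hypothetical compact way to write the masks table above):

```matlab
% Bits of 0:511, one column per number, least significant bit in row 1
% (same layout as de2bi(0:511, 9).', since de2bi puts the LSB first by default)
v = 0:511;
possible_neighborhoods = mod(floor(bsxfun(@rdivide, v, 2.^(0:8).')), 2);

% The 25 masks, rebuilt from a hypothetical clique list (column-major positions)
cliques = {5, 6, 4, 8, 2, 1, 9, 7, 3, ...
           [1 2], [1 4], [4 7], [7 8], [8 9], [6 9], ...
           [3 6], [2 3], [4 8], [6 8], [2 4], [2 6], ...
           [1 2 4], [4 7 8], [6 8 9], [2 3 6]};
masks = zeros(numel(cliques), 9);
for r = 1:numel(cliques)
    masks(r, cliques{r}) = 1;
end

% LUT(v+1) = number of masks fully covered by the bit pattern of v
found = masks * possible_neighborhoods;
LUT = sum(bsxfun(@eq, found, sum(masks, 2)), 1);
```

For the centre pixel of the question's matrix I, the label-2 window has bits set at positions 2, 3, 5, 6 and 8, i.e. index 182, and LUT(183) is 10: ten of the 25 cliques match.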

Now, instead of multiplying a binary vector for each pixel by the masks, we want a decimal index into the look-up table. For that we can use conv2 with a kernel that effectively does a binary-to-decimal conversion:

k = [1   8   64;
     2  16  128;
     4  32  256];

or

k = [2^0  2^3  2^6
     2^1  2^4  2^7
     2^2  2^5  2^8];

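This gives each pixel a value in 0:511, so we add one to get 1:512 and use that as an index into the look-up table. One caveat: conv2 rotates the kernel by 180 degrees, so the window element at (1,1) is actually weighted by 256, not 1. That does not matter for these 25 masks, because the set is symmetric under a 180-degree rotation, but with a non-symmetric set of masks you would pass rot90(k, 2) to conv2 (or use filter2, which does not flip the kernel) so that the bits line up with the de2bi ordering. Here's the full code: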
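conv2 performs true convolution, so it rotates the kernel by 180 degrees before sliding it; filter2 performs correlation and does not. A small check on one binary window (B is the label-2 window of the question's centre pixel; the variable names are illustrative):

```matlab
k = [1 8 64; 2 16 128; 4 32 256];     % k(:).' is 2.^(0:8)
B = [0 0 0; 1 1 1; 1 1 0];            % binary 3x3 window

direct  = 2.^(0:8) * B(:);            % bit c-1 set <=> B(c) == 1 (column-major)
via_rot = conv2(B, rot90(k, 2), 'valid');   % undo conv2's flip
via_f2  = filter2(k, B, 'valid');           % correlation: no flip
via_raw = conv2(B, k, 'valid');       % flipped: index of rot90(B, 2) instead
fprintf('direct %g, rot90 %g, filter2 %g, raw conv2 %g\n', ...
        direct, via_rot, via_f2, via_raw);
```

The two corrected variants give 182, the column-major binary index of B; plain conv2(B, k) gives 218, which is the index of rot90(B, 2).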

function G = gibbs2(img, beta, classcount)

   masks = [
      0   0   0   0   1   0   0   0   0;
      0   0   0   0   0   1   0   0   0;
      0   0   0   1   0   0   0   0   0;
      0   0   0   0   0   0   0   1   0;
      0   1   0   0   0   0   0   0   0;
      1   0   0   0   0   0   0   0   0;
      0   0   0   0   0   0   0   0   1;
      0   0   0   0   0   0   1   0   0;
      0   0   1   0   0   0   0   0   0;
      1   1   0   0   0   0   0   0   0;
      1   0   0   1   0   0   0   0   0;
      0   0   0   1   0   0   1   0   0;
      0   0   0   0   0   0   1   1   0;
      0   0   0   0   0   0   0   1   1;
      0   0   0   0   0   1   0   0   1;
      0   0   1   0   0   1   0   0   0;
      0   1   1   0   0   0   0   0   0;
      0   0   0   1   0   0   0   1   0;
      0   0   0   0   0   1   0   1   0;
      0   1   0   1   0   0   0   0   0;
      0   1   0   0   0   1   0   0   0;
      1   1   0   1   0   0   0   0   0;
      0   0   0   1   0   0   1   1   0;
      0   0   0   0   0   1   0   1   1;
      0   1   1   0   0   1   0   0   0];

   [m,n] = size(img);
   G = zeros(m, n, classcount);   % preallocate the output
   possible_matches = size(masks, 1);
   possible_neighborhoods = de2bi(0:511, 9).';   % de2bi: Communications Toolbox
   found = masks * possible_neighborhoods;
   target = repmat(sum(masks, 2), [1, size(found, 2)]);
   LUT = sum(found == target, 1);
   
   k = [1   8   64;
        2  16  128;
        4  32  256];
        
   Imlabel = padarray(img, [1 1], 'replicate', 'both');

   for ii = 1:classcount
      filterImage = conv2(double(Imlabel == ii), k, 'valid');
      matchImg = LUT(filterImage + 1);
      G(:,:,ii) = -beta * matchImg + beta * (possible_matches - matchImg);
   end
   
end
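A toolbox-free version of the look-up-table method can be checked against the matrix-multiplication method on a random label image. A sketch, assuming 3 classes, beta = 0.5 and replicate padding; it uses filter2 (correlation) so that the index bits line up with the LUT without rotating k. The clique list and variable names are hypothetical:

```matlab
cliques = {5, 6, 4, 8, 2, 1, 9, 7, 3, ...
           [1 2], [1 4], [4 7], [7 8], [8 9], [6 9], ...
           [3 6], [2 3], [4 8], [6 8], [2 4], [2 6], ...
           [1 2 4], [4 7 8], [6 8 9], [2 3 6]};
masks = zeros(numel(cliques), 9);
for r = 1:numel(cliques)
    masks(r, cliques{r}) = 1;
end
possible_matches = size(masks, 1);
target = sum(masks, 2);

% Look-up table over all 512 binary 3x3 patterns (bit c-1 <-> window position c)
bits = mod(floor(bsxfun(@rdivide, 0:511, 2.^(0:8).')), 2);
LUT = sum(bsxfun(@eq, masks * bits, target), 1);

beta = 0.5;
classcount = 3;
img = randi(classcount, 40, 30);
[m, n] = size(img);
P = img([1 1:end end], [1 1:end end]);    % replicate padding
k = [1 8 64; 2 16 128; 4 32 256];

G_lut = zeros(m, n, classcount);
G_mat = zeros(m, n, classcount);
for ii = 1:classcount
    B = double(P == ii);
    % LUT method: correlation with k gives each window's index 0..511
    idx = filter2(k, B, 'valid');
    cnt = LUT(idx + 1);
    G_lut(:, :, ii) = -beta * cnt + beta * (possible_matches - cnt);
    % Matrix-multiplication method on explicitly gathered windows
    cols = zeros(9, m * n);
    r = 0;
    for dc = 0:2
        for dr = 0:2
            r = r + 1;
            blk = B(1 + dr:m + dr, 1 + dc:n + dc);
            cols(r, :) = blk(:).';
        end
    end
    cnt2 = sum(bsxfun(@eq, masks * cols, target), 1);
    G_mat(:, :, ii) = reshape(-beta * cnt2 + beta * (possible_matches - cnt2), m, n);
end
fprintf('max |G_lut - G_mat| = %g\n', max(abs(G_lut(:) - G_mat(:))));
```

Both methods count matches with the same masks, so the two results are identical; only the amount of work per pixel differs.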

Since we're doing far fewer multiplications for a 1000x1000 image, this approach is about 7x faster than the Matrix Multiplication method on my machine using Octave.
