
How can I speed up my code (example included) in MATLAB?

I want to speed up my code. Currently I use if statements, which is slow. Code like this can be made faster with convolution, but I only know how to do that for the simple case (pairwise neighborhoods). Let me define my problem.

I have a matrix I=[1 1 1;2 2 2;2 2 1] with two labels {1,2}. Its padded version is shown on the right side of the figure. For each pixel in I we can define pairwise or triple neighborhoods. The cost is computed with this rule: "if the neighborhood values have the same class as the pixel, the cost is -beta; otherwise the cost is beta".

[Figure: the padded matrix and the neighborhood cases; image not available]

For example, consider the yellow pixel in the figure. Its label is 2. We need to compute the total cost over all possible neighborhood cases, shown on the rightmost side. The pixels of interest take labels from {1,2}. The figure shows only the first case, in which the yellow pixel is set to 1; the next case is the same figure with the yellow pixel set to 2. My task is to compute the cost function based on the rule above.
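To pin the rule down, here is a minimal MATLAB/Octave sketch that evaluates it for a single 3x3 window. It assumes the 25 cliques used in the code below (the pixel itself, its 8 neighbors, the 12 three-pixel cliques and the 4 squares), written as hypothetical lists of column-major positions in the window (1 = top-left, 5 = the pixel, 9 = bottom-right). The window used is the neighborhood of the middle pixel of I, which is I itself.

```matlab
% Minimal sketch: Gibbs cost of one 3x3 window for each candidate label.
% Assumption: each clique is a list of column-major positions in the window;
% the first 9 are single positions, then 12 pairs, then 4 triples.
beta = 1;
win = [1 1 1; 2 2 2; 2 2 1];      % neighborhood of the middle pixel of I

cliques = {5, 6, 4, 8, 2, 1, 9, 7, 3, ...
           [1 2], [1 4], [4 7], [7 8], [8 9], [6 9], ...
           [3 6], [2 3], [4 8], [6 8], [2 4], [2 6], ...
           [1 2 4], [4 7 8], [6 8 9], [2 3 6]};

cost = zeros(1, 2);
for label = 1:2
    % true for each clique whose pixels all carry the candidate label
    same = cellfun(@(q) all(win(q) == label), cliques);
    % -beta for every agreeing clique, +beta for every other one
    cost(label) = sum(-beta * same + beta * ~same);
end
disp(cost)
```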

This is my code. It uses if statements and is very slow. Could you help me speed it up? I tried the convolution approach, but I have no idea how to define a mask for the triple neighborhoods. Thanks!

function U=compute_gibbs(Imlabel,beta,num_class)
% Example call: U = compute_gibbs([1 1 1;2 2 2;2 2 1], 1, 2);
U=zeros([size(Imlabel) num_class]);
Imlabel = padarray(Imlabel,[1 1],'replicate','both');
[row,col] = size(Imlabel);
for ii = 2:row-1        
    for jj = 2:col-1
        for l = 1:num_class
            U(ii-1,jj-1,l)=GibbsEnergy(Imlabel,ii,jj,l,beta);
        end
    end
end
function energy = GibbsEnergy(img,i,j,label,beta)
    % img is the labeled image
    energy = 0;
    if (label == img(i,j)) energy = energy-beta;
        else energy = energy+beta;end        
    % North, south, east and west
    if (label == img(i-1,j)) energy = energy-beta;
        else energy = energy+beta;end
    if (label == img(i,j+1)) energy = energy-beta;
        else energy = energy+beta;end
    if (label == img(i+1,j)) energy = energy-beta;
        else energy = energy+beta;end
    if (label == img(i,j-1)) energy = energy-beta;
        else energy = energy+beta;end
    % diagonal elements
    if (label == img(i-1,j-1)) energy = energy-beta;
        else energy = energy+beta;end
    if (label == img(i-1,j+1)) energy = energy-beta;
        else energy = energy+beta;end
    if (label == img(i+1,j+1)) energy = energy-beta;
        else energy = energy+beta;end
    if (label == img(i+1,j-1)) energy = energy-beta;
        else energy = energy+beta;end
     %% Triangle elements
    % Case a 
    if(label==img(i-1,j)&label==img(i-1,j-1)) energy = energy-beta;
        else energy = energy+beta;end 
    if(label==img(i,j-1)&label==img(i+1 ,j)) energy = energy-beta;
        else energy = energy+beta;end
    if(label==img(i,j+1)&&label==img(i+1 ,j+1)) energy = energy-beta;
        else energy = energy+beta;end 
    % Case b 
    if(label==img(i-1,j-1)&label==img(i,j-1)) energy = energy-beta;
        else energy = energy+beta;end     
     if(label==img(i-1,j)&label==img(i ,j+1)) energy = energy-beta;
        else energy = energy+beta;end  
     if(label==img(i+1,j)&label==img(i+1,j+1)) energy = energy-beta;
         else energy = energy+beta;end  
    % Case c   
    if(label==img(i,j-1)&label==img(i+1,j-1)) energy = energy-beta;
         else energy = energy+beta;end  
    if(label==img(i+1,j)&label==img(i,j+1)) energy = energy-beta;
         else energy = energy+beta;end  
    if(label==img(i-1 ,j)&label==img(i-1,j+1)) energy = energy-beta;
        else energy = energy+beta;end 
    % Case d 
    if(label==img(i,j-1)&label==img(i-1,j)) energy = energy-beta;
        else energy = energy+beta;end 
    if(label==img(i-1 ,j+1)&label==img(i,j+1)) energy = energy-beta;
        else energy = energy+beta;end 
     if(label==img(i+1,j-1)&label==img(i+1 ,j)) energy = energy-beta;
        else energy = energy+beta;end 

    %% Rectangular
    if(label==img(i-1,j-1)&label==img(i,j-1)&label==img(i-1 ,j)) energy = energy-beta;
        else energy = energy+beta;end 
    if(label==img(i,j-1)&label==img(i+1,j-1)&label==img(i+1 ,j)) energy = energy-beta;
        else energy = energy+beta;end 
     if(label==img(i+1,j)&label==img(i +1 ,j+1)&label==img(i,j+1)) energy = energy-beta;
        else energy = energy+beta;end 
     if(label==img(i-1,j)&label==img(i-1,j+1)&label==img(i ,j+1)) energy = energy-beta;
        else energy = energy+beta;end 

This is one faster approach, but it only works for the simple case (the pairwise neighborhoods in the first row of the figure), while my case also includes single, triple, ... neighborhoods:

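% One indicator layer per class, scaled by beta
C = double(bsxfun(@eq, Imlabel, permute(1:num_class, [1 3 2])));
C(C == 0) = 0;   % (no effect)
C(C == 1) = beta;
%% Replace if statement
mask = zeros(3,3); mask(2,2) = 1;
Cpad = convn(C, mask);   % 'full' convolution with a delta: zero-pads by 1
Cpad(Cpad == 0) = 0;     % (no effect)

mask2 = ones(3,3); mask2(2,2) = 0;
energy = convn(Cpad, mask2, 'valid');   % beta times the number of 8-neighbors of each class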

Here are my attempts. I can't tell whether either one will be faster for you, because I'm using Octave rather than MATLAB and the timings can be wildly different (for instance, for loops are still very slow in Octave). You'll have to test them and see how they compare.

Matrix Multiplication

As @AnderBiguri notes in the comments, one way to go is matrix multiplication. Take a 3x3 neighborhood, say

nbr = [0 0 0;
       1 0 0;
       1 1 0];

If you want to know whether the top-left element is a 1, you can multiply it element-wise by the mask

mask = [1 0 0;
        0 0 0;
        0 0 0];

result = sum(mask(:) .* nbr(:));

(I'm taking a shortcut here by assuming that the neighborhood is a binary matrix. When I get to the actual code, I'll simply use nbr == current_class to make this so.)

If result equals the number of 1s in the mask, you've got a match. Here the element-wise product is all zeros, so there is no match.
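The same count-the-ones test also answers the question's problem of building a convolution mask for a triple: put a 1 at each neighbor in the clique, filter the class-indicator image with it, and keep the pixels where the result equals the number of ones. A minimal sketch for the clique made of the north and north-west neighbors (case a in the question). It assumes filter2, which correlates rather than convolves, so the mask is laid out exactly like the neighborhood, and it does replicate padding by indexing instead of with padarray.

```matlab
I = [1 1 1; 2 2 2; 2 2 1];
P = I([1 1:end end], [1 1:end end]);   % replicate padding by indexing
B = double(P == 1);                     % indicator of class 1 in the padded image

% Clique: north-west and north neighbors of the pixel
M = [1 1 0;
     0 0 0;
     0 0 0];

% A pixel matches only if every masked neighbor is class 1
hits = filter2(M, B, 'valid') == nnz(M);

% Contribution of this clique to the class-1 energy: -beta on a match, +beta otherwise
beta = 1;
energy_part = beta * (1 - 2 * hits);
disp(hits)
```

Repeating this for each of the 25 masks and summing the contributions gives the full energy; the matrix-multiplication and look-up-table versions below do the same work with fewer passes.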

Rather than element-wise multiplication followed by summing the elements of the result, we can just make nbr and mask into vectors and use vector multiplication:

m = mask(:).';
n = nbr(:);
result = m * n;

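This gives the same value as the previous result. If you stack several masks as the rows of a matrix, you can multiply that matrix by the neighborhood vector and get all of the results at once. So the first step is to generate the 25 mask vectors, one row per clique, in the column-major order of the 3x3 window: the center pixel, its 8 neighbors, the 12 three-pixel cliques and the 4 squares: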

masks = [
   0   0   0   0   1   0   0   0   0;
   0   0   0   0   0   1   0   0   0;
   0   0   0   1   0   0   0   0   0;
   0   0   0   0   0   0   0   1   0;
   0   1   0   0   0   0   0   0   0;
   1   0   0   0   0   0   0   0   0;
   0   0   0   0   0   0   0   0   1;
   0   0   0   0   0   0   1   0   0;
   0   0   1   0   0   0   0   0   0;
   1   1   0   0   0   0   0   0   0;
   1   0   0   1   0   0   0   0   0;
   0   0   0   1   0   0   1   0   0;
   0   0   0   0   0   0   1   1   0;
   0   0   0   0   0   0   0   1   1;
   0   0   0   0   0   1   0   0   1;
   0   0   1   0   0   1   0   0   0;
   0   1   1   0   0   0   0   0   0;
   0   0   0   1   0   0   0   1   0;
   0   0   0   0   0   1   0   1   0;
   0   1   0   1   0   0   0   0   0;
   0   1   0   0   0   1   0   0   0;
   1   1   0   1   0   0   0   0   0;
   0   0   0   1   0   0   1   1   0;
   0   0   0   0   0   1   0   1   1;
   0   1   1   0   0   1   0   0   0];
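Typing 25 rows by hand is error-prone. A minimal sketch that builds the same matrix from lists of column-major positions in the 3x3 window (1 = top-left, 5 = center, 9 = bottom-right); the lists are transcribed from the rows above, in the same order, and the `cliques` name is just for illustration.

```matlab
cliques = {5, 6, 4, 8, 2, 1, 9, 7, 3, ...
           [1 2], [1 4], [4 7], [7 8], [8 9], [6 9], ...
           [3 6], [2 3], [4 8], [6 8], [2 4], [2 6], ...
           [1 2 4], [4 7 8], [6 8 9], [2 3 6]};

masks = zeros(numel(cliques), 9);
for r = 1:numel(cliques)
    masks(r, cliques{r}) = 1;   % set the positions belonging to clique r
end

% 9 single-position masks, 12 two-position masks, 4 three-position masks
disp(sum(masks, 2).')
```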

Now when you multiply masks by the neighborhood, you get all of the results at once. Then you compare the result to the sums of the rows of masks to see which ones match.

result = masks * n;
matches = sum(masks, 2) == result;
match_count = sum(matches);
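A tiny worked instance of this test, with just two hypothetical masks stacked as rows and the example neighborhood `nbr` from earlier:

```matlab
nbr = [0 0 0;
       1 0 0;
       1 1 0];
n = nbr(:);

% Row 1: top-left pixel only. Row 2: middle-left and bottom-left pixels.
two_masks = [1 0 0 0 0 0 0 0 0;
             0 1 1 0 0 0 0 0 0];

result  = two_masks * n;                 % how many masked pixels are 1
matches = sum(two_masks, 2) == result;   % a match needs every masked pixel to be 1
disp([result matches])
```

The first mask finds no 1s (no match); the second finds both of its pixels set (a match).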

For each match we subtract beta from the energy, and for each non-match we add beta, so

possible_matches = 25; %// the number of neighborhood types
energy = -beta * match_count + beta * (possible_matches - match_count);
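Since each of the N = 25 cliques contributes either -beta or +beta, the energy depends only on the match count M and collapses to one expression:

```latex
E = -\beta M + \beta\,(N - M) = \beta\,(N - 2M), \qquad N = 25,\quad 0 \le M \le N
```

So for beta > 0 the energy runs from 25 beta (no matching clique) down to -25 beta (every clique matches).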

Now all we have to do is extract all of the 3x3 neighborhoods from the image. Fortunately, MATLAB has the im2col function (Image Processing Toolbox), which does just that. Even better, in 'sliding' mode it returns only the neighborhoods that lie entirely inside the image, so once the image is padded you're ready to go.

function G = gibbs(img, beta, classcount)

   masks = [
      0   0   0   0   1   0   0   0   0;
      0   0   0   0   0   1   0   0   0;
      0   0   0   1   0   0   0   0   0;
      0   0   0   0   0   0   0   1   0;
      0   1   0   0   0   0   0   0   0;
      1   0   0   0   0   0   0   0   0;
      0   0   0   0   0   0   0   0   1;
      0   0   0   0   0   0   1   0   0;
      0   0   1   0   0   0   0   0   0;
      1   1   0   0   0   0   0   0   0;
      1   0   0   1   0   0   0   0   0;
      0   0   0   1   0   0   1   0   0;
      0   0   0   0   0   0   1   1   0;
      0   0   0   0   0   0   0   1   1;
      0   0   0   0   0   1   0   0   1;
      0   0   1   0   0   1   0   0   0;
      0   1   1   0   0   0   0   0   0;
      0   0   0   1   0   0   0   1   0;
      0   0   0   0   0   1   0   1   0;
      0   1   0   1   0   0   0   0   0;
      0   1   0   0   0   1   0   0   0;
      1   1   0   1   0   0   0   0   0;
      0   0   0   1   0   0   1   1   0;
      0   0   0   0   0   1   0   1   1;
      0   1   1   0   0   1   0   0   0];

   [m,n] = size(img);
   possible_matches = size(masks, 1);
   Imlabel = padarray(img, [1 1], 'replicate', 'both');

   % one column per pixel: its 3x3 neighborhood in column-major order
   col_label = im2col(Imlabel, [3 3], 'sliding');
   target = repmat(sum(masks, 2), [1, m*n]);

   G = zeros(m, n, classcount);   % preallocate the output
   for ii = 1:classcount
      found = masks*double(col_label == ii);
      match_count = sum(found == target, 1);
      energy = -beta * match_count + beta*(possible_matches - match_count);
      G(:,:,ii) = reshape(energy, m, n);
   end

end
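im2col and padarray come from the Image Processing Toolbox (the image package in Octave). If that isn't available, a sketch like the following emulates im2col(P, [3 3], 'sliding') with index arithmetic. The window order (down each column first, then across) is assumed to match im2col's, so compare the two on your system before relying on it; magic(5) is only a stand-in for a padded label image.

```matlab
P = magic(5);                 % stand-in for a padded label image
[R, C] = size(P);

% Top-left corner of every 3x3 window, rows varying fastest
[rr, cc] = ndgrid(1:R-2, 1:C-2);
base = sub2ind([R C], rr(:).', cc(:).');

% Column-major offsets of the 9 window elements from the top-left corner
offs = [0 1 2, R R+1 R+2, 2*R 2*R+1 2*R+2].';

% 9 x (number of windows) matrix, one neighborhood per column
cols = P(bsxfun(@plus, offs, base));
```

Replicate padding can be done without padarray in the same spirit: `img([1 1:end end], [1 1:end end])`.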

Look-Up Table

If you look at the Matrix Multiplication solution, it multiplies the neighborhood of each pixel by the 25 masks. For a 1000 x 1000 image, that's 1000 x 1000 x 25 x 9 = 225M multiplications. But for a given class each of the 9 pixels either matches or doesn't, so there are only 512 (2^9) possible neighborhood configurations. If we multiply those 512 configurations by the masks and count the matches, we get a 512-element look-up table, and all we have to do for each neighborhood in the image is calculate its index. Here's how to create the look-up table using the masks from above:

possible_neighborhoods = de2bi(0:511, 9).';
found = masks * possible_neighborhoods;
target = repmat(sum(masks, 2), [1, size(found, 2)]);
LUT = sum(found == target, 1);

This is pretty much what we were doing in each loop before, but for all possible neighborhoods, which correspond to the bit patterns of the numbers 0:511.
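de2bi belongs to the Communications Toolbox (the communications package in Octave). If it isn't installed, the same 9 x 512 bit matrix, with the least significant bit in the first row as in de2bi's default bit order, can be built with plain arithmetic. A sketch:

```matlab
values = 0:511;
weights = (2 .^ (0:8)).';     % 9 x 1 place values 1, 2, ..., 256

% Row b holds bit b-1 of every value (row 1 = least significant bit)
possible_neighborhoods = mod(floor(bsxfun(@rdivide, values, weights)), 2);
```

Multiplying `weights.'` by this matrix recovers 0:511, which is a quick way to confirm the bit order.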

Now, instead of a binary vector for each pixel that we multiply by the masks, we want a decimal index into the look-up table. For that we can use conv2 with a kernel that effectively does a binary-to-decimal conversion:

k = [1   8   64;
     2  16  128;
     4  32  256];

or

k = [2^0  2^3  2^6
     2^1  2^4  2^7
     2^2  2^5  2^8];
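One subtlety: conv2 performs true convolution, so it rotates k by 180 degrees before sliding it. With k as written, the top-left pixel of a window gets weight 256 rather than 1, which reverses the bit order relative to the mask columns. The full code still gives the right counts because the set of 25 masks happens to be unchanged by a 180-degree rotation, but filter2 (correlation, no rotation) or conv2 with rot90(k, 2) gives the direct mapping. A quick check on a window with only the top-left pixel set:

```matlab
k = [1   8   64;
     2  16  128;
     4  32  256];

W = zeros(3);
W(1, 1) = 1;                                 % only the top-left pixel is set

idx_conv = conv2(W, k, 'valid');             % kernel rotated: picks up k(3,3)
idx_corr = filter2(k, W, 'valid');           % no rotation: picks up k(1,1)
idx_rot  = conv2(W, rot90(k, 2), 'valid');   % same result as filter2
disp([idx_conv idx_corr idx_rot])
```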

This will give us values of 0:511 for each pixel, so we add one to get to 1:512 and use that as an index into the look-up table. Here's the full code:

function G = gibbs2(img, beta, classcount)

   masks = [
      0   0   0   0   1   0   0   0   0;
      0   0   0   0   0   1   0   0   0;
      0   0   0   1   0   0   0   0   0;
      0   0   0   0   0   0   0   1   0;
      0   1   0   0   0   0   0   0   0;
      1   0   0   0   0   0   0   0   0;
      0   0   0   0   0   0   0   0   1;
      0   0   0   0   0   0   1   0   0;
      0   0   1   0   0   0   0   0   0;
      1   1   0   0   0   0   0   0   0;
      1   0   0   1   0   0   0   0   0;
      0   0   0   1   0   0   1   0   0;
      0   0   0   0   0   0   1   1   0;
      0   0   0   0   0   0   0   1   1;
      0   0   0   0   0   1   0   0   1;
      0   0   1   0   0   1   0   0   0;
      0   1   1   0   0   0   0   0   0;
      0   0   0   1   0   0   0   1   0;
      0   0   0   0   0   1   0   1   0;
      0   1   0   1   0   0   0   0   0;
      0   1   0   0   0   1   0   0   0;
      1   1   0   1   0   0   0   0   0;
      0   0   0   1   0   0   1   1   0;
      0   0   0   0   0   1   0   1   1;
      0   1   1   0   0   1   0   0   0];

   [m,n] = size(img);
   possible_matches = size(masks, 1);
   % each column is one 9-pixel pattern (bits of 0:511, LSB first)
   possible_neighborhoods = de2bi(0:511, 9).';
   found = masks * possible_neighborhoods;
   target = repmat(sum(masks, 2), [1, size(found, 2)]);
   LUT = sum(found == target, 1);   % number of matching cliques per pattern

   % bit weights for the 9 pixels of a window
   k = [1   8   64;
        2  16  128;
        4  32  256];

   Imlabel = padarray(img, [1 1], 'replicate', 'both');

   G = zeros(m, n, classcount);   % preallocate the output
   for ii = 1:classcount
      filterImage = conv2(double(Imlabel == ii), k, 'valid');   % pattern index 0..511
      matchImg = LUT(filterImage + 1);
      G(:,:,ii) = -beta * matchImg + beta * (possible_matches - matchImg);
   end
   
end

Because it does far fewer multiplications, this approach is about 7x faster than the Matrix Multiplication method for a 1000x1000 image on my machine using Octave.
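Whatever the timings, it's worth confirming that a vectorized version gives exactly the same energies as a direct loop. The sketch below is a toolbox-free variant of the look-up-table idea (filter2 instead of conv2, index-based replicate padding instead of padarray, bit arithmetic instead of de2bi), checked against a brute-force loop over the 25 cliques on the example image from the question. It assumes the clique set used throughout this answer; variable names are illustrative.

```matlab
img = [1 1 1; 2 2 2; 2 2 1];   % example label image from the question
beta = 1;
classcount = 2;

% 25 cliques as column-major positions in a 3x3 window (same order as masks)
cliques = {5, 6, 4, 8, 2, 1, 9, 7, 3, ...
           [1 2], [1 4], [4 7], [7 8], [8 9], [6 9], ...
           [3 6], [2 3], [4 8], [6 8], [2 4], [2 6], ...
           [1 2 4], [4 7 8], [6 8 9], [2 3 6]};
nc = numel(cliques);
masks = zeros(nc, 9);
for r = 1:nc
    masks(r, cliques{r}) = 1;
end

% Look-up table: number of matching cliques for each of the 512 patterns
patterns = mod(floor(bsxfun(@rdivide, 0:511, (2 .^ (0:8)).')), 2);
LUT = sum(bsxfun(@eq, masks * patterns, sum(masks, 2)), 1);

k = [1 8 64; 2 16 128; 4 32 256];          % bit weight of each window position
P = img([1 1:end end], [1 1:end end]);     % replicate padding by indexing
[m, n] = size(img);
G = zeros(m, n, classcount);
for c = 1:classcount
    idx = filter2(k, double(P == c), 'valid');   % correlation: no kernel flip
    matches = LUT(idx + 1);
    G(:, :, c) = beta * (nc - 2 * matches);
end

% Reference: direct loop over pixels, classes and cliques
Gref = zeros(m, n, classcount);
for i = 1:m
    for j = 1:n
        win = P(i:i+2, j:j+2);
        for c = 1:classcount
            hit = cellfun(@(q) all(win(q) == c), cliques);
            Gref(i, j, c) = beta * (nc - 2 * sum(hit));
        end
    end
end
disp(isequal(G, Gref))
```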
