[英]How to speed up my code [includes example] in Matlab?
我想加快我的代碼。 目前,我使用 if 語句來做到這一點。 但是,如果我們使用卷積方式,它可以制作更快的代碼。 但是,它僅適用於簡單情況(如成對鄰域)。 讓我們定義我的問題。
我有一個矩陣I=[1 1 1;2 2 2;2 2 1]
有兩個標簽{1,2}
。 我添加了填充作為其右側。 對於I
中的每個像素,我們可以定義鄰域的成對或三元組。 我們將根據規則檢查“如果這些鄰域值與像素具有相同的類,則設置成本值等於-beta
,否則設置成本等於beta
”。
例如,讓我們考慮上圖中的黃色像素。 它的標簽是 2。我們需要計算可能的鄰域情況的總成本值,如最右側所示。 有趣像素的值將從標簽 {1,2} 設置。 上圖中。 我只展示了第一種情況,將黃色像素設置為 1。我們可以有相同的數字,但在下一種情況下設置黃色像素為 2。 我的任務是根據上述規則計算成本函數。
這是我的代碼。 但是,它使用 if 語句。 它是如此緩慢。 你能幫我加快速度嗎? 我嘗試使用卷積方式,但我不知道如何為鄰域三重定義掩碼。 謝謝大家
function U=compute_gibbs(Imlabel,beta,num_class)
num_class=2;
Imlabel=[1 1 1;2 2 2;2 2 1]
beta=1;
U=zeros([size(Imlabel) num_class]);
Imlabel = padarray(Imlabel,[1 1],'replicate','both');
[row,col] = size(Imlabel);
for ii = 2:row-1
for jj = 2:col-1
for l = 1:num_class
U(ii-1,jj-1,l)=GibbsEnergy(Imlabel,ii,jj,l,beta);
end
end
end
function energy = GibbsEnergy(img,i,j,label,beta)
% img is the labeled image
energy = 0;
if (label == img(i,j)) energy = energy-beta;
else energy = energy+beta;end
% North, south, east and west
if (label == img(i-1,j)) energy = energy-beta;
else energy = energy+beta;end
if (label == img(i,j+1)) energy = energy-beta;
else energy = energy+beta;end
if (label == img(i+1,j)) energy = energy-beta;
else energy = energy+beta;end
if (label == img(i,j-1)) energy = energy-beta;
else energy = energy+beta;end
% diagonal elements
if (label == img(i-1,j-1)) energy = energy-beta;
else energy = energy+beta;end
if (label == img(i-1,j+1)) energy = energy-beta;
else energy = energy+beta;end
if (label == img(i+1,j+1)) energy = energy-beta;
else energy = energy+beta;end
if (label == img(i+1,j-1)) energy = energy-beta;
else energy = energy+beta;end
%% Triangle elements
% Case a
if(label==img(i-1,j)&label==img(i-1,j-1)) energy = energy-beta;
else energy = energy+beta;end
if(label==img(i,j-1)&label==img(i+1 ,j)) energy = energy-beta;
else energy = energy+beta;end
if(label==img(i,j+1)&&label==img(i+1 ,j+1)) energy = energy-beta;
else energy = energy+beta;end
% Case b
if(label==img(i-1,j-1)&label==img(i,j-1)) energy = energy-beta;
else energy = energy+beta;end
if(label==img(i-1,j)&label==img(i ,j+1)) energy = energy-beta;
else energy = energy+beta;end
if(label==img(i+1,j)&label==img(i+1,j+1)) energy = energy-beta;
else energy = energy+beta;end
% Case c
if(label==img(i,j-1)&label==img(i+1,j-1)) energy = energy-beta;
else energy = energy+beta;end
if(label==img(i+1,j)&label==img(i,j+1)) energy = energy-beta;
else energy = energy+beta;end
if(label==img(i-1 ,j)&label==img(i-1,j+1)) energy = energy-beta;
else energy = energy+beta;end
% Case d
if(label==img(i,j-1)&label==img(i-1,j)) energy = energy-beta;
else energy = energy+beta;end
if(label==img(i-1 ,j+1)&label==img(i,j+1)) energy = energy-beta;
else energy = energy+beta;end
if(label==img(i+1,j-1)&label==img(i+1 ,j)) energy = energy-beta;
else energy = energy+beta;end
%% Rectangular
if(label==img(i-1,j-1)&label==img(i,j-1)&label==img(i-1 ,j)) energy = energy-beta;
else energy = energy+beta;end
if(label==img(i,j-1)&label==img(i+1,j-1)&label==img(i+1 ,j)) energy = energy-beta;
else energy = energy+beta;end
if(label==img(i+1,j)&label==img(i +1 ,j+1)&label==img(i,j+1)) energy = energy-beta;
else energy = energy+beta;end
if(label==img(i-1,j)&label==img(i-1,j+1)&label==img(i ,j+1)) energy = energy-beta;
else energy = energy+beta;end
這是一種更快的方法。 但它只適用於簡單的情況(成對的鄰里第一行),而我的情況包括單個、三重......鄰里
C = double(bsxfun(@eq, Imlabel, permute(1:num_class, [1 3 2])));
C(C == 0) = 0;
C(C == 1) = beta;
%% Replace if statement
mask = zeros(3,3); mask(2,2) = 1;
Cpad = convn(C, mask);
Cpad(Cpad == 0) = 0;
mask2 = ones(3,3); mask2(2,2) = 0;
energy = convn(Cpad, mask2, 'valid');
這是我在這方面的嘗試。 我真的不能確定任何一個對你來說是否會更快,因為我使用的是 Octave 而不是 MATLAB,而且時間可能會有很大不同。 例如, for
循環在 Octave 中仍然需要永遠。 你必須測試它們,看看它們如何比較。
正如@AnderBiguri 在評論中指出的那樣,一種方法是使用矩陣乘法。 如果你拿一個 3x3 的鄰域,說
nbr = [0 0 0;
1 0 0;
1 1 0];
並且您想知道左上角的元素是否為1
,您可以通過掩碼執行逐元素乘法
mask = [1 0 0;
0 0 0;
0 0 0];
result = sum(mask .* nbr);
(我在這里假設鄰域是一個二進制矩陣,這是一個捷徑。當我得到實際代碼時,我將簡單地使用nbr == current_class
來實現這一點。)
如果結果與掩碼具有相同數量的1
元素,那么您就匹配了。 在這種情況下,這兩者的逐元素乘法全為零,因此不匹配。
而不是逐元素乘法然后對結果的元素求和,我們可以將nbr
和mask
轉換為向量並使用向量乘法:
m = mask(:).';
n = nbr(:);
result = m * n;
這將為您提供與先前結果相同的值。 如果您有一個掩碼矩陣,則可以將其乘以鄰域向量並立即獲得所有結果。 所以第一步是生成25個掩碼向量:
masks = [
0 0 0 0 1 0 0 0 0;
0 0 0 0 0 1 0 0 0;
0 0 0 1 0 0 0 0 0;
0 0 0 0 0 0 0 1 0;
0 1 0 0 0 0 0 0 0;
1 0 0 0 0 0 0 0 0;
0 0 0 0 0 0 0 0 1;
0 0 0 0 0 0 1 0 0;
0 0 1 0 0 0 0 0 0;
1 1 0 0 0 0 0 0 0;
1 0 0 1 0 0 0 0 0;
0 0 0 1 0 0 1 0 0;
0 0 0 0 0 0 1 1 0;
0 0 0 0 0 0 0 1 1;
0 0 0 0 0 1 0 0 1;
0 0 1 0 0 1 0 0 0;
0 1 1 0 0 0 0 0 0;
0 0 0 1 0 0 0 1 0;
0 0 0 0 0 1 0 1 0;
0 1 0 1 0 0 0 0 0;
0 1 0 0 0 1 0 0 0;
1 1 0 1 0 0 0 0 0;
0 0 0 1 0 0 1 1 0;
0 0 0 0 0 1 0 1 1;
0 1 1 0 0 1 0 0 0];
現在,當您將masks
乘以鄰域時,您會立即獲得所有結果。 然后將結果與masks
行的總和進行比較,以查看哪些匹配。
result = masks * n;
matches = sum(masks, 2) == result;
match_count = sum(matches);
對於每場比賽,我們從能量中減去beta
。 對於每個不匹配,我們添加beta
,所以
possible_matches = 25; %// the number of neighborhood types
energy = -beta * match_count + beta * (possible_matches - match_count);
現在我們要做的就是弄清楚如何從我們的圖像中獲取所有 3x3 鄰域。 幸運的是,MATLAB 有im2col函數可以做到這一點。 更好的是,它只需要圖像的有效鄰域,所以如果它已經被填充,你就可以開始了。
function G = gibbs(img, beta, classcount)
masks = [
0 0 0 0 1 0 0 0 0;
0 0 0 0 0 1 0 0 0;
0 0 0 1 0 0 0 0 0;
0 0 0 0 0 0 0 1 0;
0 1 0 0 0 0 0 0 0;
1 0 0 0 0 0 0 0 0;
0 0 0 0 0 0 0 0 1;
0 0 0 0 0 0 1 0 0;
0 0 1 0 0 0 0 0 0;
1 1 0 0 0 0 0 0 0;
1 0 0 1 0 0 0 0 0;
0 0 0 1 0 0 1 0 0;
0 0 0 0 0 0 1 1 0;
0 0 0 0 0 0 0 1 1;
0 0 0 0 0 1 0 0 1;
0 0 1 0 0 1 0 0 0;
0 1 1 0 0 0 0 0 0;
0 0 0 1 0 0 0 1 0;
0 0 0 0 0 1 0 1 0;
0 1 0 1 0 0 0 0 0;
0 1 0 0 0 1 0 0 0;
1 1 0 1 0 0 0 0 0;
0 0 0 1 0 0 1 1 0;
0 0 0 0 0 1 0 1 1;
0 1 1 0 0 1 0 0 0];
[m,n] = size(img);
possible_matches = size(masks, 1);
Imlabel = padarray(img, [1 1], 'replicate', 'both');
col_label = im2col(Imlabel, [3 3], 'sliding');
target = repmat(sum(masks, 2), [1, m*n]);
for ii = 1:classcount
found = masks*(col_label == ii);
match_count = sum(found == target, 1);
energy = -beta * match_count + beta*(possible_matches - match_count);
G(:,:,ii) = reshape(energy, m, n);
end
end
如果您查看矩陣乘法解決方案,它會將每個像素的鄰域乘以 25 個掩碼。 對於 1000 x 1000 圖像,即1000 x 1000 x 25 x 9 = 225M
乘法。 但是只有512
(2^9) 個可能的鄰居配置。 所以如果我們弄清楚這 512 個配置是什么,將它們乘以掩碼,然后總結匹配,我們就有了一個 512 個元素的查找表,我們需要為圖像中的每個鄰域做的就是計算它的指數。 以下是使用上面的masks
創建查找表的方法:
possible_neighborhoods = de2bi(0:511, 9).';
found = masks * possible_neighborhoods;
target = repmat(sum(masks, 2), [1, size(found, 2)]);
LUT = sum(found == target, 1);
這幾乎是我們之前在每個循環中所做的,但我們對所有可能的鄰域都這樣做,這相當於數字0:511
所有位模式。
現在,我們想要一個十進制索引到查找表中,而不是每個像素乘以掩碼的二進制向量。 為此,我們可以將conv2
與一個有效地進行二進制到十進制轉換的內核一起使用:
k = [1 8 64;
2 16 128;
4 32 256];
or
k = [2^0 2^3 2^6
2^1 2^4 2^7
2^2 2^5 2^8];
這將為我們提供每個像素的0:511
值,因此我們添加一個以獲得1:512
並將其用作查找表的索引。 這是完整的代碼:
function G = gibbs2(img, beta, classcount)
masks = [
0 0 0 0 1 0 0 0 0;
0 0 0 0 0 1 0 0 0;
0 0 0 1 0 0 0 0 0;
0 0 0 0 0 0 0 1 0;
0 1 0 0 0 0 0 0 0;
1 0 0 0 0 0 0 0 0;
0 0 0 0 0 0 0 0 1;
0 0 0 0 0 0 1 0 0;
0 0 1 0 0 0 0 0 0;
1 1 0 0 0 0 0 0 0;
1 0 0 1 0 0 0 0 0;
0 0 0 1 0 0 1 0 0;
0 0 0 0 0 0 1 1 0;
0 0 0 0 0 0 0 1 1;
0 0 0 0 0 1 0 0 1;
0 0 1 0 0 1 0 0 0;
0 1 1 0 0 0 0 0 0;
0 0 0 1 0 0 0 1 0;
0 0 0 0 0 1 0 1 0;
0 1 0 1 0 0 0 0 0;
0 1 0 0 0 1 0 0 0;
1 1 0 1 0 0 0 0 0;
0 0 0 1 0 0 1 1 0;
0 0 0 0 0 1 0 1 1;
0 1 1 0 0 1 0 0 0];
[m,n] = size(img);
possible_matches = size(masks, 1);
possible_neighborhoods = de2bi(0:511, 9).'; %'
found = masks * possible_neighborhoods;
target = repmat(sum(masks, 2), [1, size(found, 2)]);
LUT = sum(found == target, 1);
k = [1 8 64;
2 16 128;
4 32 256];
Imlabel = padarray(img, [1 1], 'replicate', 'both');
for ii = 1:classcount
filterImage = conv2(double(Imlabel == ii), k, 'valid');
matchImg = LUT(filterImage + 1);
G(:,:,ii) = -beta * matchImg + beta * (possible_matches - matchImg);
end
end
由於我們對 1000x1000 圖像執行的乘法要少得多,因此這種方法比我的機器上使用 Octave 的矩陣乘法方法快 7 倍。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.