簡體   English   中英

matlab / octave - 廣義矩陣乘法

[英]matlab/octave - Generalized matrix multiplication

我想做一個函數來推廣矩陣乘法。 基本上,它應該能夠進行標准矩陣乘法,但它應該允許通過任何其他函數更改兩個二元運算符product / sum。

目標是在CPU和內存方面盡可能高效。 當然,它總是比A * B效率低,但操作員的靈活性才是最重要的。

以下是我在閱讀各種 有趣 線程后可以提出的一些命令:

A = randi(10, 2, 3);
B = randi(10, 3, 4);

% 1st method
C = sum(bsxfun(@mtimes, permute(A,[1 3 2]),permute(B,[3 2 1])), 3)
% Alternative: C = bsxfun(@(a,b) mtimes(a',b), A', permute(B, [1 3 2]))

% 2nd method
C = sum(bsxfun(@(a,b) a*b, permute(A,[1 3 2]),permute(B,[3 2 1])), 3)

% 3rd method (Octave-only)
C = sum(permute(A, [1 3 2]) .* permute(B, [3 2 1]), 3)

% 4th method (Octave-only): multiply nxm A with nx1xd B to create a nxmxd array
C = bsxfun(@(a, b) sum(times(a,b)), A', permute(B, [1 3 2]));
C = C2 = squeeze(C(1,:,:)); % sum and turn into mxd

方法1-3的問題在於它們將在使用sum()折疊它們之前生成n個矩陣。 4更好,因為它在bsxfun中執行sum(),但是bsxfun仍然生成n個矩陣(除了它們大部分是空的,只包含一個非零值向量的總和,其余的用0填充以匹配尺寸要求)。

我想要的是像第四種方法,但沒有無用的0來節省內存。

任何的想法?

以下是您發布的解決方案稍微更精致的版本,並進行了一些小的改進。

我們檢查是否有更多的行而不是列,或者相反,然后通過選擇將行與矩陣或矩陣與列相乘(從而進行最少量的循環迭代)來相應地進行乘法。

A * B

注意 :即使行數少於列,這可能並不總是最好的策略(按行而不是列); MATLAB數組以內存中的列主要順序存儲的事實使得按行分割更有效,因為元素是連續存儲的。 訪問行涉及通過步幅遍歷元素(這不是緩存友好的 - 考慮空間局部性 )。

除此之外,代碼應該處理雙/單,實/復,滿/稀(以及不可能組合的錯誤)。 它還尊重空矩陣和零維度。

function C = my_mtimes(A, B, outFcn, inFcn)
    % default arguments
    if nargin < 4, inFcn = @times; end
    if nargin < 3, outFcn = @sum; end

    % check valid input
    assert(ismatrix(A) && ismatrix(B), 'Inputs must be 2D matrices.');
    assert(isequal(size(A,2),size(B,1)),'Inner matrix dimensions must agree.');
    assert(isa(inFcn,'function_handle') && isa(outFcn,'function_handle'), ...
        'Expecting function handles.')

    % preallocate output matrix
    M = size(A,1);
    N = size(B,2);
    if issparse(A)
        args = {'like',A};
    elseif issparse(B)
        args = {'like',B};
    else
        args = {superiorfloat(A,B)};
    end
    C = zeros(M,N, args{:});

    % compute matrix multiplication
    % http://en.wikipedia.org/wiki/Matrix_multiplication#Inner_product
    if M < N
        % concatenation of products of row vectors with matrices
        % A*B = [a_1*B ; a_2*B ; ... ; a_m*B]
        for m=1:M
            %C(m,:) = A(m,:) * B;
            %C(m,:) = sum(bsxfun(@times, A(m,:)', B), 1);
            C(m,:) = outFcn(bsxfun(inFcn, A(m,:)', B), 1);
        end
    else
        % concatenation of products of matrices with column vectors
        % A*B = [A*b_1 , A*b_2 , ... , A*b_n]
        for n=1:N
            %C(:,n) = A * B(:,n);
            %C(:,n) = sum(bsxfun(@times, A, B(:,n)'), 2);
            C(:,n) = outFcn(bsxfun(inFcn, A, B(:,n)'), 2);
        end
    end
end

對照

毫無疑問,該函數在整個過程中都會變慢,但對於較大的大小,它比內置矩陣乘法更糟糕的數量級:

        (tic/toc times in seconds)
      (tested in R2014a on Windows 8)

    size      mtimes       my_mtimes 
    ____    __________     _________
     400     0.0026398       0.20282
     600      0.012039       0.68471
     800      0.014571        1.6922
    1000      0.026645        3.5107
    2000       0.20204         28.76
    4000        1.5578        221.51

mtimes_vs_mymtimes

這是測試代碼:

sz = [10:10:100 200:200:1000 2000 4000];
t = zeros(numel(sz),2);
for i=1:numel(sz)
    n = sz(i); disp(n)
    A = rand(n,n);
    B = rand(n,n);

    tic
    C = A*B;
    t(i,1) = toc;
    tic
    D = my_mtimes(A,B);
    t(i,2) = toc;

    assert(norm(C-D) < 1e-6)
    clear A B C D
end

semilogy(sz, t*1000, '.-')
legend({'mtimes','my_mtimes'}, 'Interpreter','none', 'Location','NorthWest')
xlabel('Size N'), ylabel('Time [msec]'), title('Matrix Multiplication')
axis tight

額外

為了完整性,下面是兩種更簡單的方法來實現廣義矩陣乘法(如果要比較性能,請將my_mtimes函數的最后一部分替換為其中任何一個)。 我甚至不打算發布他們經過的時間:)

C = zeros(M,N, args{:});
for m=1:M
    for n=1:N
        %C(m,n) = A(m,:) * B(:,n);
        %C(m,n) = sum(bsxfun(@times, A(m,:)', B(:,n)));
        C(m,n) = outFcn(bsxfun(inFcn, A(m,:)', B(:,n)));
    end
end

另一種方式(使用三重循環):

C = zeros(M,N, args{:});
P = size(A,2); % = size(B,1);
for m=1:M
    for n=1:N
        for p=1:P
            %C(m,n) = C(m,n) + A(m,p)*B(p,n);
            %C(m,n) = plus(C(m,n), times(A(m,p),B(p,n)));
            C(m,n) = outFcn([C(m,n) inFcn(A(m,p),B(p,n))]);
        end
    end
end

接下來要嘗試什么?

如果你想要提高性能,你將不得不轉向C / C ++ MEX文件來減少解釋的MATLAB代碼的開銷。 您仍然可以通過從MEX文件中調用它們來利用優化的BLAS / LAPACK例程(有關示例,請參閱本文的第二部分 )。 MATLAB附帶英特爾MKL庫,坦率地說,當涉及到英特爾處理器上的線性代數計算時,你無法擊敗它。

其他人已經在文件交換中提到了一些實現通用矩陣例程作為MEX文件的提交(參見@natan的回答)。 如果將它們與優化的BLAS庫鏈接,這些特別有效。

為什么不利用bsxfun接受任意函數的能力呢?

C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1);

這里

  • f外部函數 (corrresponding在矩陣乘法的情況下, 總結 )。 它應該接受任意大小為m x n x p的3D數組,並沿其列操作以返回1 x m x p數組。
  • g內部函數 (對應於產品在矩陣乘法的情況下)。 根據bsxfun ,它應該接受相同大小的兩個列向量,或者一個列向量和一個標量作為輸入,並作為輸出返回與輸入相同大小的列向量。

這適用於Matlab。 我沒有在Octave測試過。


示例1 :矩陣乘法:

>> f = @sum;   %// outer function: sum
>> g = @times; %// inner function: product
>> A = [1 2 3; 4 5 6];
>> B = [10 11; -12 -13; 14 15];
>> C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1)
C =
    28    30
    64    69

校驗:

>> A*B
ans =
    28    30
    64    69

例2 :考慮上面兩個矩陣

>> f = @(x,y) sum(abs(x));     %// outer function: sum of absolute values
>> g = @(x,y) max(x./y, y./x); %// inner function: "symmetric" ratio
>> C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1)
C =
   14.8333   16.1538
    5.2500    5.6346

檢查:手動計算C(1,2)

>> sum(abs( max( (A(1,:))./(B(:,2)).', (B(:,2)).'./(A(1,:)) ) ))
ans =
   16.1538

沒有深入細節,有一些工具,如mtimesxMMX ,是快速通用矩陣和標量操作例程。 您可以查看他們的代碼並根據您的需求進行調整。 它很可能比matlab的bsxfun更快。

在檢查了像bsxfun這樣的幾個處理函數之后,似乎不可能使用這些函數進行直接矩陣乘法(我的意思是直接的是臨時產品沒有存儲在內存中但是盡快求和,然后是其他總和產品處理),因為它們具有固定大小的輸出(或者與輸入相同,或者使用bsxfun單例擴展,兩個輸入的維度的笛卡爾積)。 然而,有可能稍微欺騙Octave(這對於檢查輸出維度的MatLab不起作用):

C = bsxfun(@(a,b) sum(bsxfun(@times, a, B))', A', sparse(1, size(A,1)))
C = bsxfun(@(a,b) sum(bsxfun(@times, a, B))', A', zeros(1, size(A,1), 2))(:,:,2)

但是不要使用它們,因為輸出值不可靠(Octave可能會破壞甚至刪除它們並返回0!)。

所以現在我只是實現了一個半矢量化的版本,這是我的功能:

function C = genmtimes(A, B, outop, inop)
% C = genmtimes(A, B, inop, outop)
% Generalized matrix multiplication between A and B. By default, standard sum-of-products matrix multiplication is operated, but you can change the two operators (inop being the element-wise product and outop the sum).
% Speed note: about 100-200x slower than A*A' and about 3x slower when A is sparse, so use this function only if you want to use a different set of inop/outop than the standard matrix multiplication.

if ~exist('inop', 'var')
    inop = @times;
end

if ~exist('outop', 'var')
    outop = @sum;
end

[n, m] = size(A);
[m2, o] = size(B);

if m2 ~= m
    error('nonconformant arguments (op1 is %ix%i, op2 is %ix%i)\n', n, m, m2, o);
end


C = [];
if issparse(A) || issparse(B)
    C = sparse(o,n);
else
    C = zeros(o,n);
end

A = A';
for i=1:n
    C(:,i) = outop(bsxfun(inop, A(:,i), B))';
end
C = C';

end

使用稀疏矩陣和普通矩陣進行測試:稀疏矩陣(慢3倍)的性能差距遠小於普通矩陣(約100倍慢)。

我認為這比bsxfun實現慢,但至少它不會溢出內存:

A = randi(10, 1000);
C = genmtimes(A, A');

如果有人有更好的提供,我仍然在尋找更好的選擇!

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM