[英]matlab/octave - Generalized matrix multiplication
我想做一个函数来推广矩阵乘法。 基本上,它应该能够进行标准矩阵乘法,但它应该允许通过任何其他函数更改两个二元运算符product / sum。
目标是在CPU和内存方面尽可能高效。 当然,它总是比A * B效率低,但操作员的灵活性才是最重要的。
A = randi(10, 2, 3);
B = randi(10, 3, 4);
% 1st method
C = sum(bsxfun(@mtimes, permute(A,[1 3 2]),permute(B,[3 2 1])), 3)
% Alternative: C = bsxfun(@(a,b) mtimes(a',b), A', permute(B, [1 3 2]))
% 2nd method
C = sum(bsxfun(@(a,b) a*b, permute(A,[1 3 2]),permute(B,[3 2 1])), 3)
% 3rd method (Octave-only)
C = sum(permute(A, [1 3 2]) .* permute(B, [3 2 1]), 3)
% 4th method (Octave-only): multiply nxm A with nx1xd B to create a nxmxd array
C = bsxfun(@(a, b) sum(times(a,b)), A', permute(B, [1 3 2]));
C = C2 = squeeze(C(1,:,:)); % sum and turn into mxd
方法1-3的问题在于它们将在使用sum()折叠它们之前生成n个矩阵。 4更好,因为它在bsxfun中执行sum(),但是bsxfun仍然生成n个矩阵(除了它们大部分是空的,只包含一个非零值向量的总和,其余的用0填充以匹配尺寸要求)。
我想要的是像第四种方法,但没有无用的0来节省内存。
任何的想法?
以下是您发布的解决方案稍微更精致的版本,并进行了一些小的改进。
我们检查是否有更多的行而不是列,或者相反,然后通过选择将行与矩阵或矩阵与列相乘(从而进行最少量的循环迭代)来相应地进行乘法。
注意 :即使行数少于列,这可能并不总是最好的策略(按行而不是列); MATLAB数组以内存中的列主要顺序存储的事实使得按行分割更有效,因为元素是连续存储的。 访问行涉及通过步幅遍历元素(这不是缓存友好的 - 考虑空间局部性 )。
除此之外,代码应该处理双/单,实/复,满/稀(以及不可能组合的错误)。 它还尊重空矩阵和零维度。
function C = my_mtimes(A, B, outFcn, inFcn)
% default arguments
if nargin < 4, inFcn = @times; end
if nargin < 3, outFcn = @sum; end
% check valid input
assert(ismatrix(A) && ismatrix(B), 'Inputs must be 2D matrices.');
assert(isequal(size(A,2),size(B,1)),'Inner matrix dimensions must agree.');
assert(isa(inFcn,'function_handle') && isa(outFcn,'function_handle'), ...
'Expecting function handles.')
% preallocate output matrix
M = size(A,1);
N = size(B,2);
if issparse(A)
args = {'like',A};
elseif issparse(B)
args = {'like',B};
else
args = {superiorfloat(A,B)};
end
C = zeros(M,N, args{:});
% compute matrix multiplication
% http://en.wikipedia.org/wiki/Matrix_multiplication#Inner_product
if M < N
% concatenation of products of row vectors with matrices
% A*B = [a_1*B ; a_2*B ; ... ; a_m*B]
for m=1:M
%C(m,:) = A(m,:) * B;
%C(m,:) = sum(bsxfun(@times, A(m,:)', B), 1);
C(m,:) = outFcn(bsxfun(inFcn, A(m,:)', B), 1);
end
else
% concatenation of products of matrices with column vectors
% A*B = [A*b_1 , A*b_2 , ... , A*b_n]
for n=1:N
%C(:,n) = A * B(:,n);
%C(:,n) = sum(bsxfun(@times, A, B(:,n)'), 2);
C(:,n) = outFcn(bsxfun(inFcn, A, B(:,n)'), 2);
end
end
end
毫无疑问,该函数在整个过程中都会变慢,但对于较大的大小,它比内置矩阵乘法更糟糕的数量级:
(tic/toc times in seconds)
(tested in R2014a on Windows 8)
size mtimes my_mtimes
____ __________ _________
400 0.0026398 0.20282
600 0.012039 0.68471
800 0.014571 1.6922
1000 0.026645 3.5107
2000 0.20204 28.76
4000 1.5578 221.51
这是测试代码:
sz = [10:10:100 200:200:1000 2000 4000];
t = zeros(numel(sz),2);
for i=1:numel(sz)
n = sz(i); disp(n)
A = rand(n,n);
B = rand(n,n);
tic
C = A*B;
t(i,1) = toc;
tic
D = my_mtimes(A,B);
t(i,2) = toc;
assert(norm(C-D) < 1e-6)
clear A B C D
end
semilogy(sz, t*1000, '.-')
legend({'mtimes','my_mtimes'}, 'Interpreter','none', 'Location','NorthWest')
xlabel('Size N'), ylabel('Time [msec]'), title('Matrix Multiplication')
axis tight
为了完整性,下面是两种更简单的方法来实现广义矩阵乘法(如果要比较性能,请将my_mtimes
函数的最后一部分替换为其中任何一个)。 我甚至不打算发布他们经过的时间:)
C = zeros(M,N, args{:});
for m=1:M
for n=1:N
%C(m,n) = A(m,:) * B(:,n);
%C(m,n) = sum(bsxfun(@times, A(m,:)', B(:,n)));
C(m,n) = outFcn(bsxfun(inFcn, A(m,:)', B(:,n)));
end
end
另一种方式(使用三重循环):
C = zeros(M,N, args{:});
P = size(A,2); % = size(B,1);
for m=1:M
for n=1:N
for p=1:P
%C(m,n) = C(m,n) + A(m,p)*B(p,n);
%C(m,n) = plus(C(m,n), times(A(m,p),B(p,n)));
C(m,n) = outFcn([C(m,n) inFcn(A(m,p),B(p,n))]);
end
end
end
如果你想要提高性能,你将不得不转向C / C ++ MEX文件来减少解释的MATLAB代码的开销。 您仍然可以通过从MEX文件中调用它们来利用优化的BLAS / LAPACK例程(有关示例,请参阅本文的第二部分 )。 MATLAB附带英特尔MKL库,坦率地说,当涉及到英特尔处理器上的线性代数计算时,你无法击败它。
其他人已经在文件交换中提到了一些实现通用矩阵例程作为MEX文件的提交(参见@natan的回答)。 如果将它们与优化的BLAS库链接,这些特别有效。
为什么不利用bsxfun
接受任意函数的能力呢?
C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1);
这里
f
是外部函数 (corrresponding在矩阵乘法的情况下, 总结 )。 它应该接受任意大小为m
x n
x p
的3D数组,并沿其列操作以返回1
x m
x p
数组。 g
是内部函数 (对应于产品在矩阵乘法的情况下)。 根据bsxfun
,它应该接受相同大小的两个列向量,或者一个列向量和一个标量作为输入,并作为输出返回与输入相同大小的列向量。 这适用于Matlab。 我没有在Octave测试过。
示例1 :矩阵乘法:
>> f = @sum; %// outer function: sum
>> g = @times; %// inner function: product
>> A = [1 2 3; 4 5 6];
>> B = [10 11; -12 -13; 14 15];
>> C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1)
C =
28 30
64 69
校验:
>> A*B
ans =
28 30
64 69
例2 :考虑上面两个矩阵
>> f = @(x,y) sum(abs(x)); %// outer function: sum of absolute values
>> g = @(x,y) max(x./y, y./x); %// inner function: "symmetric" ratio
>> C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1)
C =
14.8333 16.1538
5.2500 5.6346
检查:手动计算C(1,2)
:
>> sum(abs( max( (A(1,:))./(B(:,2)).', (B(:,2)).'./(A(1,:)) ) ))
ans =
16.1538
在检查了像bsxfun这样的几个处理函数之后,似乎不可能使用这些函数进行直接矩阵乘法(我的意思是直接的是临时产品没有存储在内存中但是尽快求和,然后是其他总和产品处理),因为它们具有固定大小的输出(或者与输入相同,或者使用bsxfun单例扩展,两个输入的维度的笛卡尔积)。 然而,有可能稍微欺骗Octave(这对于检查输出维度的MatLab不起作用):
C = bsxfun(@(a,b) sum(bsxfun(@times, a, B))', A', sparse(1, size(A,1)))
C = bsxfun(@(a,b) sum(bsxfun(@times, a, B))', A', zeros(1, size(A,1), 2))(:,:,2)
但是不要使用它们,因为输出值不可靠(Octave可能会破坏甚至删除它们并返回0!)。
所以现在我只是实现了一个半矢量化的版本,这是我的功能:
function C = genmtimes(A, B, outop, inop)
% C = genmtimes(A, B, inop, outop)
% Generalized matrix multiplication between A and B. By default, standard sum-of-products matrix multiplication is operated, but you can change the two operators (inop being the element-wise product and outop the sum).
% Speed note: about 100-200x slower than A*A' and about 3x slower when A is sparse, so use this function only if you want to use a different set of inop/outop than the standard matrix multiplication.
if ~exist('inop', 'var')
inop = @times;
end
if ~exist('outop', 'var')
outop = @sum;
end
[n, m] = size(A);
[m2, o] = size(B);
if m2 ~= m
error('nonconformant arguments (op1 is %ix%i, op2 is %ix%i)\n', n, m, m2, o);
end
C = [];
if issparse(A) || issparse(B)
C = sparse(o,n);
else
C = zeros(o,n);
end
A = A';
for i=1:n
C(:,i) = outop(bsxfun(inop, A(:,i), B))';
end
C = C';
end
使用稀疏矩阵和普通矩阵进行测试:稀疏矩阵(慢3倍)的性能差距远小于普通矩阵(约100倍慢)。
我认为这比bsxfun实现慢,但至少它不会溢出内存:
A = randi(10, 1000);
C = genmtimes(A, A');
如果有人有更好的提供,我仍然在寻找更好的选择!
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.