繁体   English   中英

如何忽略矩阵乘法中的零?

[英]How to ignore zeros in matrix multiplications?

假设我有一个10000 x 10000矩阵W,随机数,两个10000暗淡矢量U和V,U中有随机数,V用零填充。 使用numpy或pytorch,计算U @ W和V @ W需要相同的时间。 我的问题是,有没有一种方法可以优化矩阵乘法,使其在计算过程中跳过或忽略零,所以像V @ W这样的东西会更快地计算出来?

import numpy as np
W = np.random.rand(10000, 10000)

U = np.random.rand(10000)
V = np.zeros(10000)

y1 = U @ W
y2 = V @ W
# computing y2 should take less amount of time than y1 since it always returns zero vector.

您可以使用scipy.sparse类来提高性能,但这完全取决于矩阵。 例如,使用V作为稀疏矩阵获得的性能将是很好的。 通过将U转换为稀疏矩阵而获得的结果不会很大,或者实际上可能会降低性能(在U实际上是密集的情况下)。

import numpy as np
import scipy.sparse as sps

W = np.random.rand(10000, 10000)
U = np.random.rand(10000)
V = np.zeros(10000)

%timeit U @ W
125 ms ± 1.45 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit V @ W
128 ms ± 6.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Vsp = sps.csr_matrix(V)
Usp = sps.csr_matrix(U)
Wsp = sps.csr_matrix(W)

%timeit Vsp.dot(Wsp)
1.34 ms ± 15.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 
%timeit Vsp @ Wsp
1.39 ms ± 37.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit Usp @ Wsp
2.37 s ± 84.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

正如您所看到的,对于V @ W使用稀疏方法有一个重大改进,但实际上您降低了U @ W性能,因为U或W中的条目都不为零。

In [274]: W = np.random.rand(10000, 10000) 
     ...:  
     ...: U = np.random.rand(10000) 
     ...: V = np.zeros(10000)                                                                            
In [275]: timeit U@W                                                                                     
125 ms ± 263 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [276]: timeit V@W                                                                                     
153 ms ± 18.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

现在考虑一个案例,其中V 100个元素是非零(1s)。 稀疏实现可能是:

In [277]: Vdata=np.ones(100); Vind=np.arange(0,10000,100)                                                
In [278]: Vind.shape                                                                                     
Out[278]: (100,)
In [279]: timeit Vdata@W[Vind,:]                                                                         
4.99 ms ± 102 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

我此时有点意外,认为W的索引可以抵消乘法时间。

让我们改变V来验证结果:

In [280]: V[Vind]=1                                                                                      
In [281]: np.allclose(V@W, Vdata@W[Vind,:])  

如果我必须先找到非零元素怎么办:

In [282]: np.allclose(np.where(V),Vind)                                                                  
Out[282]: True
In [283]: timeit idx=np.where(V); V[idx]@W[idx,:]                                                        
5.07 ms ± 77.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

W的大小,特别是第二维可能是这个加速的重要因素。 在这些大小下,内存管理可以像原始乘法一样影响速度。

===

在这种情况下, sparse比我预期的要好(其他测试表明我需要稀疏度在1%左右才能获得时间优势):

In [294]: from scipy import sparse                                                                       
In [295]: Vc=sparse.csr_matrix(V)                                                                        
In [296]: Vc.dot(W)                                                                                      
Out[296]: 
array([[46.01437545, 50.46422246, 44.80337192, ..., 55.57660691,
        45.54413903, 48.28613399]])
In [297]: V.dot(W)                                                                                       
Out[297]: 
array([46.01437545, 50.46422246, 44.80337192, ..., 55.57660691,
       45.54413903, 48.28613399])
In [298]: np.allclose(Vc.dot(W),V@W)                                                                     
Out[298]: True

In [299]: timeit Vc.dot(W)                                                                               
1.48 ms ± 84.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

即使稀疏创作:

In [300]: timeit Vm=sparse.csr_matrix(V); Vm.dot(W)                                                      
2.01 ms ± 7.89 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM