[英]Speed up left matmul for scipy.sparse.csr_matrix
我需要執行以下矩陣乘法: x * A[idx]
其中A
是scipy.sparse.csr_matrix
, idx
是np.array
索引。 由於索引,我無法將其更改為csc_matrix
。 它似乎比右矩陣乘法A[idx] * x
慢 50 倍,並且僅比左矩陣乘法(在整個索引上) u * A
快一點,即使len(idx) << A.shape[0]
。 我怎樣才能加快速度?
由於我沒有在其他地方找到解決方案,這就是我最終要做的。
使用numba
,我寫道:
@njit(nogil=True)
def fast_csr_vm(x, data, indptr, indices, d, idx):
"""
Returns the vector matrix product x * M[idx]. M is described
in the csr format.
Returns x * M[idx]
x: 1-d iterable
data: data field of a scipy.sparse.csr_matrix
indptr: indptr field of a scipy.sparse.csr_matrix
indices: indices field of a scipy.sparse.csr_matrix
d: output dimension
idx: 1-d iterable: index of the sparse.csr_matrix
"""
res = np.zeros(d)
assert x.shape[0] == len(idx)
for k, i in np.ndenumerate(idx):
for j in range(indptr[i], indptr[i+1]):
j_idx = indices[j]
res[j_idx] += x[k] * data[j]
return res
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.