簡體   English   中英

如何獲得稀疏矩陣數據數組的對角元素的索引

[英]How to get indices of diagonal elements of a sparse matrix data array

我有csr格式的稀疏矩陣,例如:

>>> a = sp.random(3, 3, 0.6, format='csr')  # an example
>>> a.toarray()  # just to see how it looks like
array([[0.31975333, 0.88437035, 0.        ],
       [0.        , 0.        , 0.        ],
       [0.14013856, 0.56245834, 0.62107962]])
>>> a.data  # data array
array([0.31975333, 0.88437035, 0.14013856, 0.56245834, 0.62107962])

對於此特定示例,我想獲得[0, 4] ,它們是非零對角元素0.319753330.62107962的數據數組索引。

一種簡單的方法如下:

ind = []
seen = set()
for i, val in enumerate(a.data):
    if val in a.diagonal() and val not in seen:
        ind.append(i)
        seen.add(val)

但實際上,矩陣很大,因此我不想使用for循環或使用toarray()方法轉換為numpy數組。 有更有效的方法嗎?

編輯 :我剛剛意識到,當存在非對角線元素等於和位於一些對角線元素之前的情況下,以上代碼給出了錯誤的結果:它返回該非對角線元素的索引。 同樣,它不返回重復對角元素的索引。 例如:

a = np.array([[0.31975333, 0.88437035, 0.        ],
              [0.62107962, 0.31975333, 0.        ],
              [0.14013856, 0.56245834, 0.62107962]])
a = sp.csr_matrix(a)

>>> a.data
array([0.31975333, 0.88437035, 0.62107962, 0.31975333, 0.14013856,
       0.56245834, 0.62107962])

我的代碼返回ind = [0, 2] ,但應為[0, 3, 6] Andras Deak提供的代碼(他的get_rowwise函數)返回正確的結果。

我發現了一個可能更有效的解決方案,盡管它仍在循環。 但是,它在矩陣的行上而不是元素本身上循環。 根據矩陣的稀疏模式,此速度可能會更快,也可能不會更快。 對於具有N行的稀疏矩陣,這可以保證花費N次迭代。

我們只遍歷每一行,通過a.indicesa.indptr獲取填充的列索引,如果給定行的對角線元素出現在填充值中,則我們計算其索引:

import numpy as np
import scipy.sparse as sp

def orig_loopy(a):
    ind = []
    seen = set()
    for i, val in enumerate(a.data):
        if val in a.diagonal() and val not in seen:
            ind.append(i)
            seen.add(val)
    return ind

def get_rowwise(a):
    datainds = []
    indices = a.indices # column indices of filled values
    indptr = a.indptr   # auxiliary "pointer" to data indices
    for irow in range(a.shape[0]):
        rowinds = indices[indptr[irow]:indptr[irow+1]] # column indices of the row
        if irow in rowinds:
            # then we've got a diagonal in this row
            # so let's find its index
            datainds.append(indptr[irow] + np.flatnonzero(irow == rowinds)[0])
    return datainds

a = sp.random(300, 300, 0.6, format='csr')
orig_loopy(a) == get_rowwise(a) # True

對於具有相同密度的(300,300)形隨機輸入,原始版本在3.7秒內運行,新版本在5.5毫秒內運行。

方法1

這是一種矢量化方法,該方法首先生成所有非零索引,然后獲取行索引和列索引相同的位置。 這有點慢,並且內存使用率很高。

import numpy as np
import scipy.sparse as sp
import numba as nb

def get_diag_ind_vec(csr_array):
  inds=csr_array.nonzero()
  return np.array(np.where(inds[0]==inds[1])[0])

方法二

只要使用Compiler,例如,循環方法通常就性能而言不會有問題。 NumbaCython 我為可能發生的最大對角元素分配了內存。 如果此方法占用大量內存,則可以輕松對其進行修改。

@nb.jit()
def get_diag_ind(csr_array):
    ind=np.empty(csr_array.shape[0],dtype=np.uint64)
    rowPtr=csr_array.indptr
    colInd=csr_array.indices

    ii=0
    for i in range(rowPtr.shape[0]-1):
      for j in range(rowPtr[i],rowPtr[i+1]):
        if (i==colInd[j]):
          ind[ii]=j
          ii+=1

    return ind[:ii]

計時

csr_array = sp.random(1000, 1000, 0.5, format='csr')

get_diag_ind_vec(csr_array)   -> 8.25ms
get_diag_ind(csr_array)       -> 0.65ms (first call excluded)

這是我的解決方案,它似乎比get_rowwise (Andras Deak)和get_diag_ind_vec (max9111)快(我不考慮使用Numba或Cython)。

想法是將矩陣(或其副本)的非零對角元素設置為不在原始矩陣中的某個唯一值x (我選擇了最大值+ 1),然后簡單地使用np.where(a.data == x)以返回所需的索引。

def diag_ind(a):
    a = a.copy()
    i = a.diagonal() != 0  
    x = np.max(a.data) + 1
    a[i, i] = x
    return np.where(a.data == x)

定時:

A = sp.random(1000, 1000, 0.5, format='csr')

>>> %timeit diag_ind(A)
6.32 ms ± 335 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

>>> %timeit get_diag_ind_vec(A)
14.6 ms ± 292 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

>>> %timeit get_rowwise(A)
24.3 ms ± 5.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

編輯:復制稀疏矩陣(以保留原始矩陣)的存儲效率不高,因此更好的解決方案是存儲對角線元素,然后將其用於恢復原始矩陣。

def diag_ind2(a):
    a_diag = a.diagonal()
    i = a_diag != 0  
    x = np.max(a.data) + 1
    a[i, i] = x
    ind = np.where(a.data == x)
    a[i, i] = a_diag[np.nonzero(a_diag)]
    return ind

這甚至更快:

>>> %timeit diag_ind2(A)
2.83 ms ± 419 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM