簡體   English   中英

從分隔的字符串創建scipy.sparse.csr_matrix

[英]create a scipy.sparse.csr_matrix from a delimited string

我正在處理大量的二進制數據,這些數據以\\t\\t1\\t\\t1\\t\\t\\t (但更長)的形式逐行進入程序。 可以想象,這些是制表符分隔文件中的行。

顯然,我可以執行'\\t\\t1\\t\\t1\\t\\t\\t'.split('\\t')並獲得1''的列表,我可以很容易地將其變成1和0或T / F等等。 但是,數據非常稀疏(很多0而不是很多1),因此我希望使用某種稀疏表示形式。

我的問題是:沒有人知道直接從此字符串轉換為scipy.sparse.csr_matrix() 而不 scipy.sparse.csr_matrix()創建中間密集矩陣的方法嗎?

我嘗試將拆分字符串(即1''的列表)直接傳遞給csr_matrix() ,但遇到TypeError: no supported conversion for types: (dtype('<U1'),)

就像我說的那樣,我可以執行上述操作,並得到1和0,然后將轉換為csr_matrix()但是由於我始終在創建完全密集的版本,因此我失去了稀疏的所有速度和內存優勢。

scipy無法解釋您的輸入,因為它不知道您希望將空字符串轉換為0。這很好:

>>> from scipy.sparse import csr_matrix
>>> x = [0 if not a else int(a) for a in "\t\t\t\t1\t\t\t1\t\t\t".split('\t')] 
>>> csr_matrix(x)
<1x11 sparse matrix of type '<class 'numpy.int64'>'
        with 2 stored elements in Compressed Sparse Row format>

在矩陣化之前,請確保您的列表全部為numbrt格式。

在我回憶起OP的評論之后,您可以強制將空字符串轉換為0, 因此更好的解決方案

>>> csr_matrix("\t\t\t\t1\t\t\t1\t\t\t".split('\t'),dtype=np.int64)
<1x11 sparse matrix of type '<class 'numpy.int64'>'
        with 2 stored elements in Compressed Sparse Row format>

少生成一個列表。

這是一種逐行處理數據的方法:

In [32]: astr = '\t\t1\t\t1\t\t\t'      # sample row
In [33]: row, col = [],[]
In [34]: for i in range(5):
    ...:     c = [j for j,v in enumerate(astr.split('\t')) if v]
    ...:     row.extend([i]*len(c))
    ...:     col.extend(c)
    ...: data = np.ones(len(col),'int32')
    ...: M = sparse.csr_matrix((data, (row, col)))
    ...: 
In [35]: M
Out[35]: 
<5x5 sparse matrix of type '<class 'numpy.int32'>'
    with 10 stored elements in Compressed Sparse Row format>
In [36]: M.A
Out[36]: 
array([[0, 0, 1, 0, 1],
       [0, 0, 1, 0, 1],
       [0, 0, 1, 0, 1],
       [0, 0, 1, 0, 1],
       [0, 0, 1, 0, 1]], dtype=int32)

對於每一行,我僅收集“ 1”的索引。 通過這些,我構建了相應的datarow列表(或數組)。 從理論上講,我可以構造indptr來進行更直接的csr創建,但是coo風格更容易理解。

中間值為:

In [40]: c
Out[40]: [2, 4]
In [41]: row
Out[41]: [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
In [42]: col
Out[42]: [2, 4, 2, 4, 2, 4, 2, 4, 2, 4]
In [43]: data
Out[43]: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)

獲取c值的另一種方法是:

In [46]: np.where(astr.split('\t'))[0]
Out[46]: array([2, 4])

(但列表理解速度更快)。

字符串和列表find/index方法找到第一項,但不是全部。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM