簡體   English   中英

Scipy csr_matrix無法正確復制

[英]Scipy csr_matrix does not copy correctly

我在從一個csr_matrix計算並從結果創建新對象時遇到了一些問題。 在嘗試查找它時,我只是做了一些簡單的代碼來復制原始矩陣,而副本並不相同。 我已經在很小的矩陣(如文檔中給出的)上進行了嘗試,但是在現實世界的矩陣(大約250萬個條目,所有這些都不為零)上,結果很奇怪。 這是測試代碼:

print type(X_ngrams)
tst = csr_matrix( (X_ngrams.data,X_ngrams.nonzero()))
print "Original:"
print "shape     ", X_ngrams.shape
r1,c1=X_ngrams.nonzero()
print "rows, cols", r1[:10],c1[:10]
print "indptr    ", X_ngrams.indptr[:10]
print "indices   ", X_ngrams.indices[:10]
print "data[:10] ", X_ngrams.data[:10]
#
print
print "Copy:"
print "shape     ", tst.shape
r2,c2=tst.nonzero()
print "rows, cols", r2[:10],c2[:10]
print "indptr    ", tst.indptr[:10]
print "indices   ", tst.indices[:10]
print "data[:10] ", tst.data[:10]

結果如下:

<class 'scipy.sparse.csr.csr_matrix'>
Original:
shape      (2257, 202262)
rows, cols [0 0 0 0 0 0 0 0 0 0] [ 69627  70494 168418 174006 157892     161787 146945 148354  51951  53422]
indptr     [   0  518 1247 3156 3634 4368 5594 6670 8540 9257]
indices    [ 69627  70494 168418 174006 157892 161787 146945 148354  51951  53422]
data[:10]  [ 2  1 23  1 35  1 11  1  8  1]

Copy:
shape      (2257, 202262)
rows, cols [0 0 0 0 0 0 0 0 0 0] [1439 2461 2561 2683 2748 4279 6212 6275 6332 6611]
indptr     [   0  518 1247 3156 3634 4368 5594 6670 8540 9257]
indices    [1439 2461 2561 2683 2748 4279 6212 6275 6332 6611]
data[:10]  [20  1  1  1  1  1  1  1  1  1]

為什么副本的結構不同? 我需要創建的矩陣應該具有完全相同的結構,每個位置的數字都不同。

我無法使用您提供的數據來復制您的問題,但是我懷疑問題出在副本被排序時X_ngrams沒有被排序。 排序是由nonzero執行的。

比較兩個indices 兩者都是第一行中500多個值的一小部分:

indices    [ 69627  70494 168418 174006 157892 161787 146945 148354  51951  53422]
indices    [1439 2461 2561 2683 2748 4279 6212 6275 6332 6611]

第二個列表較小,並且已排序。 X_ngrams.has_sorted_indices的值是X_ngrams.has_sorted_indices

您真正需要比較的是兩者都不nonzero

一種解決方案是先對X_ngrams排序

 X._ngrams.sort_indices()  # sort in place

您可能還考慮使用M.copy()M.tocsr(copy=True) M.sorted_indices()返回帶有排序索引的副本。

此格式:

sparse.csr_matrix((M.data, M.indices, M.indptr))

使用相同的數組M進行復制。 或者,如果您希望它們成為副本:

sparse.csr_matrix((M.data.copy(), M.indices.copy(), M.indptr.copy()))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM