簡體   English   中英

使用`coo`矩陣和python中的numpy數組加快for循環的操作

[英]Speed up operations on a for loop with `coo` matrix and a numpy array in python

我有一個numpy數組和一個coo矩陣。 我需要基於coo矩陣中的元素更新numpy數組。 numpy數組和矩陣都非常大,如下所示:

 graph_array = [[  1.0   1.0   5.0  9.0]
 [  2.0   5.0   6.0   5.0]
 [  3.0   5.0   7.0   6.0]]

matrix_coo = (1, 5) 0.5
(2, 8)  0.4
(5, 7)  0.8

我需要做的如下:

如果在陣列中的每個列表中的第二和第三元件即list_graph[i][1][2]其可以是1,55,65,7 )等於在一個行和列的一對coo矩陣(1, 5), (2, 8) or (5, 7)則與該對關聯的值(對於(1, 5)等於0.5 )必須替換數組中列表中的第四個元素。

我的預期輸出將是:

output_array = [[  1.0   1.0   5.0  0.5]
[  2.0   5.0   6.0   5.0]
[  3.0   5.0   7.0   0.8]]

我正在使用的當前代碼如下:

 row_idx = list(matrix_coo.row)
 col_idx = list(matrix_coo.col)
 data_idx = list(matrix_coo.data)

x = 0
    while x < len(row_cost_idx):
        for m in graph_array:
            if m[1] == row_idx[x]:
                if m[2] == col_idx[x]:
                    m[3] = data_idx[x]
        x += 1

它的確為我提供了正確的輸出,但是由於該數組有21596個項目,而矩陣有21596行,因此需要很長時間。

有更快的方法嗎?

您的迭代是純Python列表操作。 這一事實row_idx最初是作為一的屬性coo_matrix不適用

可以用以下方法清除它:

什么是row_cost_idx 如果與row_idx相同,我們可以做

for r,c,d in zip(matrix_coo.row, matrix_coo.col, matrix_coo.data):
    for m in graph_array: # not list_graph?
        if m[:2]==[r,c]:
            m[3] = d

我認為迭代是相同的,但尚未對其進行測試。 我也不知道速度。

matrix_coo非零元素和graph_array子列表上的兩次迭代注定會很慢,這僅僅是因為您要進行很多次迭代。

如果graph_array是一個numpy array ,我們可以一次測試所有行,例如

mask = (graph_array[:, :2]==[r,c]).all(axis=1)
graph_array[mask,3] = d

其中對具有正確索引的graph_array行, mask將為1。 (同樣,這未經測試)

為了提高速度,我將graph_arraymatrix_coograph_array為2d numpy(密集)數組,並查看是否可以通過一些數組操作解決問題。 從中得出的見解可能會幫助我替換matrix_coo迭代。

========================

經過測試的代碼

import numpy as np
from scipy import sparse

graph_array = np.array([[  1.0,   1.0,   5.0 , 9.0],
 [  2.0,   5.0 ,  6.0  , 5.0],
 [  3.0  , 5.0 ,  7.0 ,  6.0]])

r,c,d = [1,2,5], [5,8,7],[0.5,0.4,0.8]
matrix_coo = sparse.coo_matrix((d,(r,c)))

def org(graph_array, matrix_coo):
    row_idx = list(matrix_coo.row)
    col_idx = list(matrix_coo.col)
    data_idx = list(matrix_coo.data)

    x = 0
    while x < len(row_idx):
        for m in graph_array:
            if m[1] == row_idx[x]:
                if m[2] == col_idx[x]:
                    m[3] = data_idx[x]
        x += 1
    return graph_array

new_array = org(graph_array.copy(), matrix_coo)    
print(graph_array)
print(new_array)

def alt(graph_array, matrix_coo):
    for r,c,d in zip(matrix_coo.row, matrix_coo.col, matrix_coo.data):
        for m in graph_array: 
            if (m[[1,2]]==[r,c]).all():  # array test
                m[3] = d
    return graph_array

new_array = alt(graph_array.copy(), matrix_coo)    
print(new_array)

def altlist(graph_array, matrix_coo):
    for r,c,d in zip(matrix_coo.row, matrix_coo.col, matrix_coo.data):
        for m in graph_array:
            if (m[1:3]==[r,c]):   # list test
                m[3] = d
    return graph_array

new_array = altlist(graph_array.tolist(), matrix_coo)    
print(new_array)

def altarr(graph_array, matrix_coo):
    for r,c,d in zip(matrix_coo.row, matrix_coo.col, matrix_coo.data):
        mask = (graph_array[:, 1:3]==[r,c]).all(axis=1)
        graph_array[mask,3] = d
    return graph_array

new_array = alt(graph_array.copy(), matrix_coo)    
print(new_array)

0909:~/mypy$ python3 stack3727173.py 
[[ 1.  1.  5.  9.]
 [ 2.  5.  6.  5.]
 [ 3.  5.  7.  6.]]
[[ 1.   1.   5.   0.5]
 [ 2.   5.   6.   5. ]
 [ 3.   5.   7.   0.8]]
[[ 1.   1.   5.   0.5]
 [ 2.   5.   6.   5. ]
 [ 3.   5.   7.   0.8]]
[[1.0, 1.0, 5.0, 0.5], [2.0, 5.0, 6.0, 5.0], [3.0, 5.0, 7.0, 0.80000000000000004]]
[[ 1.   1.   5.   0.5]
 [ 2.   5.   6.   5. ]
 [ 3.   5.   7.   0.8]]

對於這個小例子,您的功能最快。 它也適用於列表和數組。 對於小型的物料清單操作,通常比數組操作更快。 因此,使用數組運算僅比較兩個數字並沒有改善。

復制graph_array 1000倍的altarr版本比您的代碼快10倍。 它正在最大范圍內執行陣列操作。 我沒有嘗試增加matrix_coo的大小。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM