
Merge rows of a 2D array if at least one value is equal

I am trying to achieve the following. Consider the matrix

[[ 9 10]
 [10  9]
 [11 17]
 [11 18]
 [12 13]
 [13 12]
 [17 11]
 [17 18]
 [18 11]
 [18 17]]

I want to "merge" all rows that share at least one value, directly or through a chain of other rows (e.g. [11, 17] and [17, 18] end up in the same group). For this example, I want to get

[[9, 10],
 [11, 17, 18],
 [12, 13]]

I know that numpy works with arrays of fixed shape, so my idea was to fill another array, padded with NaNs, with these values. A simple approach would be a for loop over every row: if the result already contains one of the row's values, extend that group; otherwise put the pair into the next free row. I did this without numpy, using a list of lists to hold the groups:

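import numpy as np

matrix = np.array([[9, 10], [10, 9], [11, 17], [11, 18], [12, 13],
                   [13, 12], [17, 11], [17, 18], [18, 11], [18, 17]])

groups = []
for pair in matrix:
    pair = [int(pair[0]), int(pair[1])]
    # Every existing group that shares a value with this pair
    matches = [g for g in groups if pair[0] in g or pair[1] in g]
    if not matches:
        groups.append(pair)
        continue
    # Merge the pair, plus any groups it bridges, into the first match
    first = matches[0]
    for other in matches[1:]:
        first.extend(other)
        groups.remove(other)
    first.extend(x for x in pair if x not in first)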
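For the fixed-shape layout described in the question (one group per row, padded with NaNs), a minimal sketch of the conversion might look like this, assuming the groups have already been collected as a list of lists. `to_nan_padded` is a hypothetical helper name, not a numpy function.

```python
import numpy as np

def to_nan_padded(groups):
    """Pack ragged groups into a 2D float array, padding short rows with NaN."""
    width = max(len(g) for g in groups)
    # NaN only exists for floats, so the result is a float array
    out = np.full((len(groups), width), np.nan)
    for row, g in enumerate(groups):
        out[row, :len(g)] = g
    return out

padded = to_nan_padded([[9, 10], [11, 17, 18], [12, 13]])
print(padded)
```

Note that the NaN padding forces a float dtype, so the integer values come back as floats (e.g. `9.0`); keeping a plain list of lists avoids that.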

Is there a better numpy way to do it?

Here is a shorter approach using np.isin and np.union1d, although it is a little hacky (it restarts the search recursively after every merge):

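import numpy as np

def find_intersection(m_list):
    # Find the first two rows that share a value
    for i, v in enumerate(m_list):
        for j, k in enumerate(m_list[i + 1:], i + 1):
            if np.isin(v, k).any():  # np.isin replaces the deprecated np.in1d
                # Replace row i by the union of both rows, drop row j, start over
                m_list[i] = np.union1d(v, m_list.pop(j))
                return find_intersection(m_list)
    # No two rows overlap any more, so every row is a finished group
    return m_list

a = list(matrix)  # a Python list of rows, since ndarrays have no pop()
find_intersection(a)
# [array([ 9, 10]), array([11, 17, 18]), array([12, 13])]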

This is a connected-components problem if you think of each row as an edge connecting two nodes. If you have networkx installed, you can use its connected_components function to solve it:
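If you would rather avoid any dependency, the same connected-components idea can be sketched with a union-find (disjoint-set) structure in plain Python. This is a swapped-in technique rather than something from the answers here, and `merge_rows` is a hypothetical name.

```python
def merge_rows(rows):
    """Group values linked through shared rows (connected components)."""
    parent = {}

    def find(x):
        # Follow parent links up to the root, halving the path as we go
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    # Every row is an edge: union the sets containing its two endpoints
    for a, b in rows:
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[rb] = ra

    # Collect nodes by root; groups appear in first-seen order
    groups = {}
    for x in parent:
        groups.setdefault(find(x), []).append(x)
    return [sorted(g) for g in groups.values()]

rows = [[9, 10], [10, 9], [11, 17], [11, 18], [12, 13],
        [13, 12], [17, 11], [17, 18], [18, 11], [18, 17]]
print(merge_rows(rows))
# [[9, 10], [11, 17, 18], [12, 13]]
```

Unlike repeatedly scanning existing groups, each row is processed once with near-constant work, and a row that bridges two earlier groups (e.g. [[1, 2], [3, 4], [2, 3]]) correctly merges them into one.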

import networkx as nx

G = nx.from_edgelist([[ 9, 10],
                     [10,  9],
                     [11, 17],
                     [11, 18],
                     [12, 13],
                     [13, 12],
                     [17, 11],
                     [17, 18],
                     [18, 11],
                     [18, 17]])
list(nx.connected_components(G))
[{9, 10}, {11, 17, 18}, {12, 13}]

Since this is tagged as a numpy problem, a SciPy solution would use scipy.sparse.csgraph.connected_components, but its usage is less well documented. I would suggest using networkx here.
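For reference, a minimal sketch of the SciPy route might look like this, assuming the rows are stored in a numpy array called `matrix`. The values are first relabelled to 0..n-1 with `np.unique` so the sparse adjacency matrix stays small; each row then becomes one edge.

```python
import numpy as np
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components

matrix = np.array([[9, 10], [10, 9], [11, 17], [11, 18], [12, 13],
                   [13, 12], [17, 11], [17, 18], [18, 11], [18, 17]])

# Relabel the values to 0..n-1; reshape covers both old and new NumPy inverse shapes
nodes, idx = np.unique(matrix, return_inverse=True)
idx = idx.reshape(matrix.shape)
n = len(nodes)

# Each row is an edge between its two (relabelled) values
adj = coo_matrix((np.ones(len(idx)), (idx[:, 0], idx[:, 1])), shape=(n, n))

# directed=False treats every edge as undirected
n_comp, labels = connected_components(adj, directed=False)

# Map component labels back to the original values
groups = [nodes[labels == c] for c in range(n_comp)]
print(groups)
```

Duplicate edges such as [10, 9] and [9, 10] are simply summed in the sparse matrix, which does not affect connectivity.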
