
Merging tuples if they have one common element

Consider the following list:

tuple_list = [('c', 'e'), ('c', 'd'), ('a', 'b'), ('d', 'e')]

How can I achieve this?

new_tuple_list = [('c', 'e', 'd'), ('a', 'b')]

I have tried:

for tuple in tuple_list:
    for tup in tuple_list:
        if tuple[0] == tup[0]:
            new_tup = (tuple[0],tuple[1],tup[1])
            new_tuple_list.append(new_tup)

But it only works if the elements of each tuple are in a certain order, so I get this instead:

new_tuple_list = [('c', 'e', 'd'), ('a', 'b'), ('d', 'e')]

You could treat the tuples as edges in a graph; your goal is then to find the connected components of that graph. Loop over the vertices (the items in the tuples) and, for each vertex you haven't visited yet, run a depth-first search (DFS) to collect its component:

from collections import defaultdict

def dfs(adj_list, visited, vertex, result, key):
    # Recursively visit every vertex reachable from `vertex`,
    # collecting them all under result[key].
    visited.add(vertex)
    result[key].append(vertex)
    for neighbor in adj_list[vertex]:
        if neighbor not in visited:
            dfs(adj_list, visited, neighbor, result, key)

edges = [('c', 'e'), ('c', 'd'), ('a', 'b'), ('d', 'e')]

# Build an undirected adjacency list from the edges.
adj_list = defaultdict(list)
for x, y in edges:
    adj_list[x].append(y)
    adj_list[y].append(x)

# Start a DFS from every vertex that is not yet in a component.
result = defaultdict(list)
visited = set()
for vertex in adj_list:
    if vertex not in visited:
        dfs(adj_list, visited, vertex, result, vertex)

print(list(result.values()))

Output:

[['a', 'b'], ['c', 'e', 'd']]

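Note that in the output above, both the components and the elements within each component can come out in any order: they follow the DFS traversal, and before Python 3.7 dictionary iteration order was not guaranteed either.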
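Because dfs above is recursive, a very large component can exceed Python's default recursion limit (1000). A minimal sketch of the same idea using an explicit stack instead of recursion (a stack-based variant of the DFS); the function name connected_components is hypothetical, and like the original it assumes every tuple is a pair (x, y):

```python
from collections import defaultdict


def connected_components(edges):
    """Return the connected components of the graph given by `edges`
    (pairs of vertices), using an explicit stack instead of recursion."""
    # Build an undirected adjacency list.
    adj_list = defaultdict(list)
    for x, y in edges:
        adj_list[x].append(y)
        adj_list[y].append(x)

    visited = set()
    components = []
    for start in adj_list:
        if start in visited:
            continue
        # The stack replaces the recursive calls.
        stack = [start]
        visited.add(start)
        component = []
        while stack:
            vertex = stack.pop()
            component.append(vertex)
            for neighbor in adj_list[vertex]:
                if neighbor not in visited:
                    visited.add(neighbor)  # mark on push so it is stacked once
                    stack.append(neighbor)
        components.append(tuple(component))
    return components


edges = [('c', 'e'), ('c', 'd'), ('a', 'b'), ('d', 'e')]
print(connected_components(edges))  # [('c', 'd', 'e'), ('a', 'b')]
```

The stack pops the most recently pushed neighbor first, so the element order inside a component can differ from the recursive version (here c, d, e instead of c, e, d); the grouping itself is the same.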

If you don't need to keep duplicate values (for example, preserving ['a', 'a', 'b']), this is a simple and fast way to do what you want with sets:

iset = {frozenset(s) for s in tuple_list}  # Convert to a set of (frozen) sets
result = []
while iset:                           # While there are sets left to process:
    nset = set(iset.pop())            # Pop a new set to grow
    check = bool(iset)                # Are there other sets to compare with?
    while check:                      # Until a full pass finds no overlap:
        check = False
        for s in iset.copy():         # For each other set:
            if nset.intersection(s):  # if it overlaps the current set:
                check = True          # other sets may now overlap too, so recheck
                iset.remove(s)        # remove it from the remaining sets
                nset.update(s)        # and add its items to the current set
    result.append(tuple(nset))        # Store the merged group as a tuple

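gives something like the following (the order of the groups, and of the items in each tuple, can vary because sets are unordered):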

[('c', 'e', 'd'), ('a', 'b')]
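A different technique that avoids rescanning the remaining sets is a union-find (disjoint-set) structure. The sketch below is an illustration, not part of the answer above; merge_tuples, find and union are hypothetical names. It accepts tuples of any length, drops repeated items, and keeps items in first-seen order:

```python
def merge_tuples(tuple_list):
    """Merge tuples that share any element, using union-find."""
    parent = {}

    def find(x):
        # Register x on first sight, then walk up to its root,
        # halving the path along the way to keep the trees shallow.
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(a, b):
        # Attach b's tree under a's root if they are different trees.
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[root_b] = root_a

    # Every item of a tuple joins the group of the tuple's first item.
    for t in tuple_list:
        for item in t[1:]:
            union(t[0], item)

    # Group items by root; dict keys keep first-seen order and drop repeats.
    groups = {}
    for t in tuple_list:
        for item in t:
            groups.setdefault(find(item), {})[item] = None
    return [tuple(members) for members in groups.values()]


tuple_list = [('c', 'e'), ('c', 'd'), ('a', 'b'), ('d', 'e')]
print(merge_tuples(tuple_list))  # [('c', 'e', 'd'), ('a', 'b')]
```

Path halving keeps each find call cheap, so the whole merge is close to linear in the total number of items.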

The task becomes trivial with NetworkX, a library for graph manipulation. Similar to the answer by @niemmi, you need to find the connected components:

import networkx as nx

tuple_list = [('c', 'e'), ('c', 'd'), ('a', 'b'), ('d', 'e')]
graph = nx.Graph(tuple_list)
result = list(nx.connected_components(graph))
print(result)
# [{'e', 'c', 'd'}, {'b', 'a'}]

To get the result as a list of tuples:

result = list(map(tuple, nx.connected_components(graph)))
print(result)
# [('d', 'e', 'c'), ('a', 'b')]

This performs poorly on large inputs, because each membership check (item in already) scans a whole tuple in O(n) and every tuple is compared against every group, but it's quite short:

result = []

for tup in tuple_list:
    for idx, already in enumerate(result):
        # check if any items are equal
        if any(item in already for item in tup):
            # tuples are immutable so we need to set the result item directly
            result[idx] = already + tuple(item for item in tup if item not in already)
            break
    else:
        # else in for-loops are executed only if the loop wasn't terminated by break
        result.append(tup)

This has the nice side-effect that the order is kept:

>>> result
[('c', 'e', 'd'), ('a', 'b')]
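One limitation of the for/else loop above: each tuple is merged only into the first group it overlaps, so a later tuple that connects two existing groups leaves them apart (for example, [('a', 'b'), ('c', 'd'), ('b', 'c')] gives [('a', 'b', 'c'), ('c', 'd')]). A minimal order-preserving sketch that collects every overlapping group before merging; merge_ordered is a hypothetical name, and unlike the loop above it also drops repeated items inside a tuple:

```python
def merge_ordered(tuple_list):
    """Merge tuples that share any item, keeping first-seen order and
    joining every existing group the current tuple overlaps."""
    result = []
    for tup in tuple_list:
        # Indices of all existing groups sharing at least one item with tup.
        hits = [i for i, group in enumerate(result)
                if any(item in group for item in tup)]
        if not hits:
            result.append(tuple(dict.fromkeys(tup)))  # new group, repeats dropped
            continue
        # Concatenate the overlapping groups and tup; dict.fromkeys removes
        # duplicates while keeping the first occurrence of each item.
        combined = [item for i in hits for item in result[i]] + list(tup)
        result[hits[0]] = tuple(dict.fromkeys(combined))
        # Delete the other merged groups, back to front so indices stay valid.
        for i in reversed(hits[1:]):
            del result[i]
    return result


print(merge_ordered([('c', 'e'), ('c', 'd'), ('a', 'b'), ('d', 'e')]))
# [('c', 'e', 'd'), ('a', 'b')]
print(merge_ordered([('a', 'b'), ('c', 'd'), ('b', 'c')]))
# [('a', 'b', 'c', 'd')]
```

The merged group takes the position of the first group it overlaps, so the order of groups still follows the input.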

I had this problem with sets, so I'm contributing my solution. It repeatedly combines sets that share one or more elements until no more merges are possible.

My example data:

data = [['A','B','C'],['B','C','D'],['D'],['X'],['X','Y'],['Y','Z'],['M','N','O'],['M','N','O'],['O','A']]
data = list(map(set,data))

My code to solve the problem:

oldlen = len(data) + 1
while len(data) < oldlen:            # repeat until a pass merges nothing
    oldlen = len(data)
    for i in range(len(data)):
        for j in range(i + 1, len(data)):
            if data[i] & data[j]:            # the sets share an element:
                data[i] = data[i] | data[j]  # merge set j into set i
                data[j] = set()              # and empty set j
    data = [s for s in data if s]    # drop the emptied sets

Result:

[{'A', 'B', 'C', 'D', 'M', 'N', 'O'}, {'X', 'Y', 'Z'}]

Use sets. You are checking for overlap and accumulation of (initially small) sets, and Python has a data type for that:

#!python3

#tuple_list = [('c', 'e'), ('c', 'd'), ('a', 'b'), ('d', 'e')]
tuple_list = [(1,2), (3,4), (5,), (1,3,5), (3,'a'),
        (9,8), (7,6), (5,4), (9,'b'), (9,7,4),
        ('c', 'e'), ('e', 'f'), ('d', 'e'), ('d', 'f'),
        ('a', 'b'),
        ]
set_list = []

print("Tuple list:", tuple_list)
for t in tuple_list:
    #print("Set list:", set_list)
    tset = set(t)
    matched = []
    for s in set_list:
        if tset & s:
            s |= tset
            matched.append(s)

    if not matched:
        #print("No matches. New set: ", tset)
        set_list.append(tset)

    elif len(matched) > 1:
        #print("Multiple Matches: ", matched)
        for i,iset in enumerate(matched):
            if not iset:
                continue
            for jset in matched[i+1:]:
                if iset & jset:
                    iset |= jset
                    jset.clear()

set_list = [s for s in set_list if s]
print('\n'.join([str(s) for s in set_list]))
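In the loop above, every matched set gets tset added to it, so all sets matched by one tuple overlap and end up merged into a single group. A minimal sketch of the same set-based approach that uses this directly: it keeps the list pairwise disjoint and lets each tuple absorb every set it overlaps in one pass, so no second merging step or final cleanup is needed (merge_into_sets is a hypothetical name):

```python
def merge_into_sets(tuple_list):
    """Keep a list of pairwise-disjoint sets; each tuple absorbs every
    existing set it overlaps, so the list stays disjoint."""
    set_list = []
    for t in tuple_list:
        tset = set(t)
        merged = set(tset)
        untouched = []
        for s in set_list:
            if s & tset:
                merged |= s          # overlaps this tuple: fold it in
            else:
                untouched.append(s)  # unrelated: keep as is
        set_list = untouched
        if merged:                   # skip empty tuples
            set_list.append(merged)
    return set_list


tuple_list = [(1, 2), (3, 4), (5,), (1, 3, 5), (3, 'a'),
              (9, 8), (7, 6), (5, 4), (9, 'b'), (9, 7, 4),
              ('c', 'e'), ('e', 'f'), ('d', 'e'), ('d', 'f'),
              ('a', 'b')]
for s in merge_into_sets(tuple_list):
    print(s)
```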

I bumped into this problem when resolving coreferences: I needed to merge the sets in a list of sets that have common elements:

import copy

def merge(list_of_sets):
    # Work on a copy so the caller's sets are not modified
    list_of_sets = copy.deepcopy(list_of_sets)
    indices = find_first_overlapping_sets(list_of_sets)
    while indices:
        # Keep the sets that are not part of the overlapping pair
        result = [
            s
            for idx, s in enumerate(list_of_sets)
            if idx not in indices
        ]
        # Append the union of the overlapping pair
        result.append(
            list_of_sets[indices[0]].union(list_of_sets[indices[1]])
        )
        # Search the updated list again
        list_of_sets = result
        indices = find_first_overlapping_sets(list_of_sets)
    return list_of_sets

def find_first_overlapping_sets(list_of_sets):
    # Return the indices (i, j) of the first pair of sets sharing an
    # element, or None if all sets are disjoint
    for i, i_set in enumerate(list_of_sets):
        for j, j_set in enumerate(list_of_sets[i+1:], start=i + 1):
            if i_set.intersection(j_set):
                return i, j
    return None

The technical post webpages of this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
Guangdong ICP Registration No. 18138465  © 2020-2024 STACKOOM.COM