
Implementing Kruskal's Algorithm in Python using Union Find

I'm trying to implement Kruskal's algorithm in Python using the union-find data structure. My implementation works on the small example I have developed here, but it has a small problem on the much larger homework graph. Can you help me see what is wrong with this implementation?

Here is my implementation:

class UnionFind:

    def __init__(self,val,leader):
        self.val = val
        self.leader = leader

    def changeLeader(self,leader):
        self.leader = leader

    def returnLeader(self):
        return self.leader

from collections import defaultdict
def kruskal(graph,edges, N):
    T = dict()
    sizes = defaultdict(lambda: 0)
    edgeWeights = []

    for indx, edge in enumerate(edges):
        n1 = graph[edge][0]
        n2 = graph[edge][1]

#         print("edge is", edge,"nodes",n1.val,n2.val,"leaders",n1.leader,n2.leader)
#         print("state of dict is", T.keys())
        if (n1.leader == n2.leader):
#             print("both nodes part of",n1.leader,'do nothing \n')
            pass

        elif (n1.leader in T.keys()) and (n2.leader not in T.keys()):
#             print("adding ",n2.val, "to group",n1.leader,'\n')
            n2.changeLeader(n1.leader)
            T[n1.leader].append(n2)
            sizes[n1.leader] += 1
            edgeWeights.append(edge)

        elif (n2.leader in T.keys()) and (n1.leader not in T.keys()):
#             print("adding ",n1.val, "to group",n2.leader,'\n')

            n1.changeLeader(n2.leader)
            T[n2.leader].append(n1)
            sizes[n2.leader] += 1
            edgeWeights.append(edge)

        elif (n1.leader in T.keys()) and (n2.leader in T.keys()) and (n1.leader != n2.leader):
#             print("merging groups",n1.leader,n2.leader)
            size1 = sizes[n1.leader]
            size2 = sizes[n2.leader]
            edgeWeights.append(edge)

#             print("sizes are",size1, size2)
            if size1 >= size2:
                for node in T[n2.leader]:
                    if node is not n2:
                        node.changeLeader(n1.leader)
                        T[n1.leader].append(node)
                        sizes[n1.leader] += 1
                        sizes[n2.leader] -= 1

                del T[n2.leader]
                sizes[n2.leader] = 0
                n2.changeLeader(n1.leader)
                T[n1.leader].append(n2)

#                 print("updated list of nodes",T.keys())
#                 for node in T[n1.leader]:
#                     print("includes",node.val)
            else:
                for node in T[n1.leader]:

                    if node is not n1:
                        node.changeLeader(n2.leader)
                        T[n2.leader].append(node)
                        sizes[n2.leader] += 1
                        sizes[n1.leader] -= 1

                del T[n1.leader]
                sizes[n1.leader] = 0
                n1.changeLeader(n2.leader)
                T[n2.leader].append(n1)
        else:
#             print("adding new group",n1.val,n2.val,'\n')
            n2.changeLeader(n1.leader)
            T[n1.leader] = [n1,n2]
            sizes[n1.leader] +=2
            edgeWeights.append(edge)

#         print("updated nodes",graph[edge][0].val,graph[edge][1].val,"leaders",
#               graph[edge][0].leader,graph[edge][1].leader,"\n")
    return T, edgeWeights

Here is the test code:

nodes = [UnionFind("A","A"),UnionFind("B","B"),UnionFind("C","C"),UnionFind("D","D"),UnionFind("E","E")]
graph = {1:[nodes[0],nodes[1]],2:[nodes[3],nodes[4]],
         3:[nodes[0],nodes[4]],4:[nodes[0],nodes[3]],
         5:[nodes[0],nodes[2]],6:[nodes[2], nodes[4]],
         7:[nodes[1],nodes[2]]}
N = 5
edges = list(graph.keys())
edges.sort()

T, weight = kruskal(graph,edges,N)

for node in T['A']:
    print(node.val)

print("edges",weight)

And the resulting output:

edge is 1 nodes A B leaders A B
state of dict is dict_keys([])
adding new group A B 

updated nodes A B leaders A A 

edge is 2 nodes D E leaders D E
state of dict is dict_keys(['A'])
adding new group D E 

updated nodes D E leaders D D 

edge is 3 nodes A E leaders A D
state of dict is dict_keys(['A', 'D'])
merging groups A D
sizes are 2 2
updated list of nodes dict_keys(['A'])
includes A
includes B
includes D
includes E
updated nodes A E leaders A A 

edge is 4 nodes A D leaders A A
state of dict is dict_keys(['A'])
both nodes part of A do nothing 

updated nodes A D leaders A A 

edge is 5 nodes A C leaders A C
state of dict is dict_keys(['A'])
adding  C to group A 

updated nodes A C leaders A A 

edge is 6 nodes C E leaders A A
state of dict is dict_keys(['A'])
both nodes part of A do nothing 

updated nodes C E leaders A A 

edge is 7 nodes B C leaders A A
state of dict is dict_keys(['A'])
both nodes part of A do nothing 

updated nodes B C leaders A A 

A
B
D
E
C
edges [1, 2, 3, 5]

So the code should end with all nodes in the graph having a single parent. At least, this is my understanding of Kruskal's algorithm. It does not do so on the larger graph, but I can't post that example here. Any ideas based on this code would be much appreciated.

"So the code should end with all nodes in the graph having a single parent."

No! The code should end with all nodes in the graph belonging to a single connected component, but this does not mean they all have the same parent in your union-find data structure. In a union-find structure, two nodes are in the same connected component if they have the same root node, but they might not have the same parent.
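To make the distinction concrete, here is a minimal sketch. The `Node` class and `find_root` function are hypothetical names for illustration, not part of the question's code, and the sketch assumes `leader` holds a reference to the parent node, with a root being its own leader.

```python
class Node:
    """Hypothetical minimal node: `leader` is a reference to the parent Node."""

    def __init__(self, val):
        self.val = val
        self.leader = self  # a root is its own leader


def find_root(node):
    # Follow leader references until reaching a node that leads itself.
    while node.leader is not node:
        node = node.leader
    return node


a, b, c = Node("A"), Node("B"), Node("C")
b.leader = a  # B's parent is A
c.leader = b  # C's parent is B, not A

same_parent = b.leader is c.leader              # parents differ (A vs. B)
same_component = find_root(b) is find_root(c)   # both chains end at A
print(same_parent, same_component)  # False True
```

B and C sit in the same component even though their parents differ, which is why comparing `leader` attributes directly is not a reliable connectivity test unless every merge rewrites the leader of every node in the absorbed group.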

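To correct your UnionFind class, make the returnLeader method search for the root node instead of just returning the parent. Note that this assumes leader holds a reference to the parent UnionFind object, with each root being its own leader, rather than a label such as "A" as in your test code: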

    def returnLeader(self):
        cur = self
        while cur != cur.leader:
            cur = cur.leader
        return cur

This is now logically correct, but we can improve the efficiency for large inputs by doing "path compression". To avoid repeating the same search many times, update the leader whenever the search finds a different root node. If we call returnLeader recursively, it will also update all the nodes along the path to the root node.

    def returnLeader(self):
        if self.leader != self.leader.leader:
            self.leader = self.leader.returnLeader()
        return self.leader
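Putting the pieces together, the following is a rough sketch of Kruskal's algorithm on top of a union-find with path compression and union by size. It is not the asker's class: `DisjointSet`, `find`, `union`, and the `(weight, u, v)` edge format are all assumptions made for this example. The `find` method is written iteratively, since a recursive version can run into Python's recursion limit on very long leader chains.

```python
class DisjointSet:
    """Union-find over hashable labels (hypothetical helper for this sketch)."""

    def __init__(self, items):
        self.parent = {x: x for x in items}  # every item starts as its own root
        self.size = {x: 1 for x in items}

    def find(self, x):
        # First pass: walk up to the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass (path compression): point each visited node at the root.
        while self.parent[x] != root:
            nxt = self.parent[x]
            self.parent[x] = root
            x = nxt
        return root

    def union(self, a, b):
        """Merge the sets containing a and b; return False if already merged."""
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return False
        if self.size[ra] < self.size[rb]:
            ra, rb = rb, ra  # attach the smaller tree under the larger one
        self.parent[rb] = ra
        self.size[ra] += self.size[rb]
        return True


def kruskal(weighted_edges, vertices):
    """Return the chosen edges, given (weight, u, v) tuples (assumed format)."""
    ds = DisjointSet(vertices)
    tree = []
    for weight, u, v in sorted(weighted_edges):
        if ds.union(u, v):  # keep the edge only if it joins two components
            tree.append((weight, u, v))
    return tree


# The example graph from the question, using each edge's key as its weight.
edges = [(1, "A", "B"), (2, "D", "E"), (3, "A", "E"), (4, "A", "D"),
         (5, "A", "C"), (6, "C", "E"), (7, "B", "C")]
mst = kruskal(edges, "ABCDE")
print([w for w, _, _ in mst])  # [1, 2, 3, 5]
```

On the question's five-node graph this selects edges 1, 2, 3 and 5, matching the output shown in the question. Connectivity is always checked by comparing roots returned by `find`, never by comparing stored parents directly.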


 