简体   繁体   中英

Prim's MST implementation with priority queue error

My algorithm calculated the correct distances between points, but then updated some of the values to a smaller value in the nested for loop for the test case below.

I suspect there's an error in the implementation of my nested for loop?

5
0 0
0 2
1 1
3 0
3 2

The answer should be 7.064495102 (i got 7.650281540)

def minimum_distance(x, y):
    result = 0.
    distance = {}
    for i in range(len(x)):
        j=i+1
        while j<len(x):
            distance[i, j] = calc_distance(x[i], y[i], x[j],y[j])
            # distance.append([i, j, calc_distance(x[i], y[i], x[j],y[j])])
            j+=1
            
    cost = []
    parent = []
    for k in range(len(x)):
        cost.append([float('inf'), k])
        parent.append([None, k]) 
    
    # [cost, vertex]
    cost[0] = [0,0]
    parent[0] = [-1,0]

    pq = min_heap(cost)
    cost_pq = pq.heapify()
    
    while len(cost_pq) != 0:
        v, cost_pq = pq.extract_min(cost_pq)
        min_cost, min_current_node = v[0], v[1]
        result += min_cost
        for edge in distance:
            for vertex in cost_pq:
                # check if cost_pq contains vertex edge[1]
                if vertex[1] == edge[1]:
                    vertex_index = cost_pq.index(vertex)
                    if cost_pq[vertex_index][0] > distance[edge[0], edge[1]]:
                        cost_pq[vertex_index][0] = distance[edge[0], edge[1]]
                        parent[edge[1]][0] = edge[0]
                        pq.heapify() 
            
    return result

To illustrate the comment, this is what should look like:

from itertools import combinations
import heapq

def distance(x1, y1, x2, y2):
    return ((x1-x2)**2 + (y1-y2)**2)**0.5

def MST(xs, ys):
    n = len(xs)
    # heap of tuples (distance, node1, node2)
    q = [(distance(xs[i], ys[i], xs[j], ys[j]), i, j) 
         for i, j in combinations(range(n), 2)]
    heapq.heapify(q)
    # each node is its own parent in the beginning
    parents = list(range(n))
    # naive parent implementation - should include depth to balance for better performance
    def parent(i):
        while parents[i] != i:
            i = parents[i]
        return i

    total_distance = 0
    while q:
        d, i, j = heapq.heappop(q)
        parent1, parent2 = parent(i), parent(j)
        if parent1 != parent2:
            total_distance += d
            parents[parent1] = j
    return total_distance

MST([0, 0, 1, 3, 3], [0, 2, 1, 0, 2]) returns 7.06449510224598

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM