
Generating an optimal binary search tree (Cormen)

I'm reading Cormen et al., Introduction to Algorithms (3rd ed.), section 15.4 on optimal binary search trees, but am having some trouble implementing the pseudocode for the optimal_bst function in Python.

Here is the example I'm trying to apply the optimal BST to:

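[image: table of the key probabilities p_i and dummy-key probabilities q_i for i = 0..5; the same values appear as p and q in the code below]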

Let us define e[i, j] as the expected cost of searching an optimal binary search tree containing the keys labeled from i to j. Ultimately, we wish to compute e[1, n], where n is the number of keys (5 in this example). The final recursive formulation is:

[image: the recurrence for e[i, j]]
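Written out, the recurrence from section 15.4 reads roughly as follows, where w(i, j) is the total probability of the keys k_i..k_j and the dummy keys d_{i-1}..d_j:

```latex
e[i,j] =
\begin{cases}
  q_{i-1} & \text{if } j = i - 1, \\
  \min\limits_{i \le r \le j} \bigl\{ e[i, r-1] + e[r+1, j] + w(i,j) \bigr\} & \text{if } i \le j,
\end{cases}
\qquad
w(i,j) = \sum_{l=i}^{j} p_l + \sum_{l=i-1}^{j} q_l .
```

The base case j = i - 1 is why e needs a column 0 and a row n + 1, which is where the mixed indexing comes from.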

which should be implemented by the following pseudocode:

[image: pseudocode of OPTIMAL-BST(p, q, n)]
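As a cross-check that does not depend on any table layout, the recurrence for e[i, j] can also be evaluated top-down with memoization. This is only a sketch for sanity-checking values, not the book's bottom-up procedure; the padding entry p[0] and the helper names w and e are choices made here for illustration.

```python
from functools import lru_cache

p = [None, 0.15, 0.10, 0.05, 0.10, 0.20]  # p[1..n]; p[0] is unused padding
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]  # q[0..n]
n = len(p) - 1


def w(i, j):
    # Total probability of keys k_i..k_j plus dummy keys d_{i-1}..d_j.
    return sum(p[i:j + 1]) + sum(q[i - 1:j + 1])


@lru_cache(maxsize=None)
def e(i, j):
    # Base case: an empty key range holds only the dummy key d_{i-1}.
    if j == i - 1:
        return q[i - 1]
    # Otherwise try every key k_r in the range as the root of the subtree.
    return min(e(i, r - 1) + e(r + 1, j) for r in range(i, j + 1)) + w(i, j)


print(round(e(1, n), 2))  # expected search cost of the whole tree
```

Any bottom-up implementation should agree with e(i, j) from this sketch for every 1 <= i <= n + 1 and i - 1 <= j <= n.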

Notice that the pseudocode mixes 1- and 0-based indexing (the tables e and w are indexed [1..n+1, 0..n], while root is indexed [1..n, 1..n]), whereas Python uses only 0-based indexing. As a consequence I'm having trouble implementing the pseudocode. Here is what I have so far:

import numpy as np

p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)

e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n+1):
    for i in range(n-l+1):
        j = i + l
        e[i, j] = np.inf
        w[i, j] = w[i, j-1] + p[j-1] + q[j]
        for r in range(i, j+1):
            t = e[i-1, r-1] + e[r, j] + w[i-1, j]
            if t < e[i-1, j]:
                e[i-1, j] = t
                root[i-1, j] = r

print(w)
print(e)

However, if I run this the weights w get computed correctly, but the expected search values e remain 'stuck' at their initialized values:

[[ 0.05  0.3   0.45  0.55  0.7   1.  ]
 [ 0.    0.1   0.25  0.35  0.5   0.8 ]
 [ 0.    0.    0.05  0.15  0.3   0.6 ]
 [ 0.    0.    0.    0.05  0.2   0.5 ]
 [ 0.    0.    0.    0.    0.05  0.35]
 [ 0.    0.    0.    0.    0.    0.1 ]]
[[ 0.05   inf   inf   inf   inf   inf]
 [ 0.    0.1    inf   inf   inf   inf]
 [ 0.    0.    0.05   inf   inf   inf]
 [ 0.    0.    0.    0.05   inf   inf]
 [ 0.    0.    0.    0.    0.05   inf]
 [ 0.    0.    0.    0.    0.    0.1 ]]

What I expect is that e , w , and root be as follows:

[image: the expected e, w, and root tables]

I've been debugging this for a couple of hours now and am still stuck. Can someone point out what is wrong with the Python code above?

It appears to me that you made a mistake in the indices. I couldn't make it work as expected, but the following code should give you an indication of where I was heading (there is probably an off-by-one error somewhere):

import numpy as np

p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)

def get2(m, i, j):
    return m[i - 1, j - 1]


def set2(m, i, j, v):
    m[i - 1, j - 1] = v


def get1(m, i):
    return m[i - 1]


def set1(m, i, v):
    m[i - 1] = v


e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n + 1):
    for i in range(n - l + 2):
        j = i + l - 1
        set2(e, i, j, np.inf)
        set2(w, i, j, get2(w, i, j - 1) + get1(p, j) + get1(q, j))
        for r in range(i, j + 1):
            t = get2(e, i, r - 1) + get2(e, r + 1, j) + get2(w, i, j)
            if t < get2(e, i, j):
                set2(e, i, j, t)
                set2(root, i, j, r)

print(w)
print(e)

The result:

[[ 0.2   0.4   0.5   0.65  0.9   0.  ]
 [ 0.    0.2   0.3   0.45  0.7   0.  ]
 [ 0.    0.    0.1   0.25  0.5   0.  ]
 [ 0.    0.    0.    0.15  0.4   0.  ]
 [ 0.    0.    0.    0.    0.25  0.  ]
 [ 0.5   0.7   0.8   0.95  0.    0.3 ]]
[[ 0.2   0.6   0.8   1.2   1.95  0.  ]
 [ 0.    0.2   0.4   0.8   1.35  0.  ]
 [ 0.    0.    0.1   0.35  0.85  0.  ]
 [ 0.    0.    0.    0.15  0.55  0.  ]
 [ 0.    0.    0.    0.    0.25  0.  ]
 [ 0.7   1.2   1.5   2.    0.    0.3 ]]

In the end I used pandas' Series and DataFrame objects, initialized with custom index and columns, to give the arrays the same indexing as in the pseudocode. After that, the pseudocode can be almost copy-pasted:

import numpy as np
import pandas as pd

P = [0.15, 0.10, 0.05, 0.10, 0.20]
Q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(P)

p = pd.Series(P, index=range(1, n+1))
q = pd.Series(Q)

e = pd.DataFrame(np.diag(Q), index=range(1, n+2))
w = pd.DataFrame(np.diag(Q), index=range(1, n+2))
root = pd.DataFrame(np.zeros((n, n)), index=range(1, n+1), columns=range(1, n+1))

# Label-based scalar access via .at (get_value/set_value were removed in pandas 1.0).
for l in range(1, n+1):
    for i in range(1, n-l+2):
        j = i+l-1
        e.at[i, j] = np.inf
        w.at[i, j] = w.at[i, j-1] + p[j] + q[j]
        for r in range(i, j+1):
            t = e.at[i, r-1] + e.at[r+1, j] + w.at[i, j]
            if t < e.at[i, j]:
                e.at[i, j] = t
                root.at[i, j] = r

print(e)
print(w)
print(root)

which yields the expected results:

      0     1     2     3     4     5
1  0.05  0.45  0.90  1.25  1.75  2.75
2  0.00  0.10  0.40  0.70  1.20  2.00
3  0.00  0.00  0.05  0.25  0.60  1.30
4  0.00  0.00  0.00  0.05  0.30  0.90
5  0.00  0.00  0.00  0.00  0.05  0.50
6  0.00  0.00  0.00  0.00  0.00  0.10
      0    1     2     3     4     5
1  0.05  0.3  0.45  0.55  0.70  1.00
2  0.00  0.1  0.25  0.35  0.50  0.80
3  0.00  0.0  0.05  0.15  0.30  0.60
4  0.00  0.0  0.00  0.05  0.20  0.50
5  0.00  0.0  0.00  0.00  0.05  0.35
6  0.00  0.0  0.00  0.00  0.00  0.10
     1    2    3    4    5
1  1.0  1.0  2.0  2.0  2.0
2  0.0  2.0  2.0  2.0  4.0
3  0.0  0.0  3.0  4.0  5.0
4  0.0  0.0  0.0  4.0  5.0
5  0.0  0.0  0.0  0.0  5.0

I would still be interested in a solution with NumPy arrays, though, as this seems more elegant to me.
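One way this might look with plain NumPy arrays is to allocate one extra row (and pad p with a dummy entry) so that the pseudocode's indices can be used unchanged. This is a sketch under that assumption; the function name optimal_bst and the unused row 0 are choices made here, not something the book prescribes.

```python
import numpy as np


def optimal_bst(p, q):
    """p holds p_1..p_n, q holds q_0..q_n; returns tables in pseudocode indexing."""
    n = len(p)
    p = np.concatenate(([0.0], p))  # pad so that p[1..n] matches the pseudocode
    # Row 0 is unused padding: e and w are addressed as [1..n+1, 0..n].
    e = np.zeros((n + 2, n + 1))
    w = np.zeros((n + 2, n + 1))
    root = np.zeros((n + 1, n + 1), dtype=int)  # addressed as [1..n, 1..n]

    # Base cases: an empty key range contains only the dummy key d_{i-1}.
    for i in range(1, n + 2):
        e[i, i - 1] = q[i - 1]
        w[i, i - 1] = q[i - 1]

    for l in range(1, n + 1):            # l = number of keys in the range
        for i in range(1, n - l + 2):
            j = i + l - 1
            e[i, j] = np.inf
            w[i, j] = w[i, j - 1] + p[j] + q[j]
            for r in range(i, j + 1):    # try each key k_r as the subtree root
                t = e[i, r - 1] + e[r + 1, j] + w[i, j]
                if t < e[i, j]:
                    e[i, j] = t
                    root[i, j] = r
    return e, w, root


e, w, root = optimal_bst([0.15, 0.10, 0.05, 0.10, 0.20],
                         [0.05, 0.10, 0.05, 0.05, 0.05, 0.10])
print(np.round(e[1:, :], 2))   # drop the padding row before printing
print(np.round(w[1:, :], 2))
print(root[1:, 1:])
```

The cost of the padding is one wasted row in e and w and one wasted row and column in root, in exchange for loop bounds and subscripts that match the pseudocode.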

Kurt, thanks for your post! Yours is the only working implementation of this problem I found. I spent a lot of time wrestling with the indices. Here is my implementation with NumPy arrays.

import numpy as np
import math

def optimalBST(p,q,n):

    e = np.zeros((n+1, n+1))
    w = np.zeros((n+1, n+1))
    root = np.zeros((n+1, n+1))

    # Initialization
    for i in range(n+1):
        e[i,i] = q[i]
        w[i,i] = q[i]
    for i in range(0,n):
        root[i,i] = i+1

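    # Row index i here corresponds to i+1 in the pseudocode; column index j is unchanged.
    for l in range(1,n+1):
        for i in range(0, n-l+1):
            j = i+l
            min_ = math.inf
            w[i,j] = w[i,j-1] + p[j] + q[j]
            for r in range(i,j):         # candidate root is key r+1
                t = e[i, r] + e[r+1,j] + w[i,j]
                if t < min_:
                    min_ = t
                    e[i, j] = t
                    root[i, j-1] = r+1   # root is shifted one column left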

    root_pruned = np.delete(np.delete(root, n, 1), n, 0)        # Trim last col & row.

    print("------ e -------")
    print(e)
    print("------ w -------")
    print(w)
    print("----- root -----")
    print(root_pruned)

def main():

    p = [0,.15,.1,.05,.1,.2]
    q = [.05,.1,.05,.05,.05,.1]
    n = len(p)-1

    optimalBST(p,q,n)

if __name__ == '__main__':
    main()

Output:

------ e -------
[[0.05 0.45 0.9  1.25 1.75 2.75]
 [0.   0.1  0.4  0.7  1.2  2.  ]
 [0.   0.   0.05 0.25 0.6  1.3 ]
 [0.   0.   0.   0.05 0.3  0.9 ]
 [0.   0.   0.   0.   0.05 0.5 ]
 [0.   0.   0.   0.   0.   0.1 ]]
------ w -------
[[0.05 0.3  0.45 0.55 0.7  1.  ]
 [0.   0.1  0.25 0.35 0.5  0.8 ]
 [0.   0.   0.05 0.15 0.3  0.6 ]
 [0.   0.   0.   0.05 0.2  0.5 ]
 [0.   0.   0.   0.   0.05 0.35]
 [0.   0.   0.   0.   0.   0.1 ]]
----- root -----
[[1. 1. 2. 2. 2.]
 [0. 2. 2. 2. 4.]
 [0. 0. 3. 4. 5.]
 [0. 0. 0. 4. 5.]
 [0. 0. 0. 0. 5.]]
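To turn the root table into an actual tree, one possible sketch is a recursive walk: root[i][j] is the root of the subtree over keys i..j, and its children come from the ranges i..r-1 and r+1..j. The function name describe_tree and the padded, 1-based copy of the root table below are assumptions made for this illustration; the values are copied from the output above.

```python
# Root table from the output above, padded with a row and column 0
# so that it can be indexed 1-based as root[i][j].
root = [
    [0, 0, 0, 0, 0, 0],
    [0, 1, 1, 2, 2, 2],
    [0, 0, 2, 2, 2, 4],
    [0, 0, 0, 3, 4, 5],
    [0, 0, 0, 0, 4, 5],
    [0, 0, 0, 0, 0, 5],
]


def describe_tree(root, i, j, parent=None, side=None, lines=None):
    """List every key k_r and dummy key d_j together with its parent."""
    if lines is None:
        lines = []
    if j < i:
        # Empty key range: the leaf at this position is the dummy key d_j.
        if parent is None:
            lines.append(f"d{j} is the root")
        else:
            lines.append(f"d{j} is the {side} child of k{parent}")
        return lines
    r = int(root[i][j])
    if parent is None:
        lines.append(f"k{r} is the root")
    else:
        lines.append(f"k{r} is the {side} child of k{parent}")
    describe_tree(root, i, r - 1, r, "left", lines)
    describe_tree(root, r + 1, j, r, "right", lines)
    return lines


for line in describe_tree(root, 1, 5):
    print(line)
```

For this table the walk starts at k2, puts k1 on its left and k5 on its right, and hangs k4 and then k3 off the left side of k5.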

