繁体   English   中英

生成最优二叉搜索树(Cormen)

[英]Generating an optimal binary search tree (Cormen)

我正在阅读 Cormen 等人的算法简介(第 3 版)( PDF ),关于最佳二叉搜索树的第 15.4 节,但是在 Python 中实现optimal_bst函数的伪代码时遇到了一些问题。

这是我尝试将最佳 BST 应用于的示例:

在此处输入图片说明

让我们将e[i,j]定义为搜索包含从ij标记的键的最优二叉搜索树的预期成本。 最终,我们希望计算e[1, n] ,其中n是键的数量(在本例中为 5)。 最终的递归公式是:

在此处输入图片说明

这应该由以下伪代码实现:

在此处输入图片说明

请注意,伪代码可互换地使用基于 1 和基于 0 的索引,而 Python 仅使用后者。 因此,我在实现伪代码时遇到了麻烦。 这是我到目前为止所拥有的:

import numpy as np

p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)

e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n+1):
    for i in range(n-l+1):
        j = i + l
        e[i, j] = np.inf
        w[i, j] = w[i, j-1] + p[j-1] + q[j]
        for r in range(i, j+1):
            t = e[i-1, r-1] + e[r, j] + w[i-1, j]
            if t < e[i-1, j]:
                e[i-1, j] = t
                root[i-1, j] = r

print(w)
print(e)

但是,如果我运行这个,权重w会得到正确计算,但预期的搜索值e仍然“卡住”在它们的初始化值上:

[[ 0.05  0.3   0.45  0.55  0.7   1.  ]
 [ 0.    0.1   0.25  0.35  0.5   0.8 ]
 [ 0.    0.    0.05  0.15  0.3   0.6 ]
 [ 0.    0.    0.    0.05  0.2   0.5 ]
 [ 0.    0.    0.    0.    0.05  0.35]
 [ 0.    0.    0.    0.    0.    0.1 ]]
[[ 0.05   inf   inf   inf   inf   inf]
 [ 0.    0.1    inf   inf   inf   inf]
 [ 0.    0.    0.05   inf   inf   inf]
 [ 0.    0.    0.    0.05   inf   inf]
 [ 0.    0.    0.    0.    0.05   inf]
 [ 0.    0.    0.    0.    0.    0.1 ]]

我期望的是ewroot如下:

在此处输入图片说明

我已经调试了几个小时了,但仍然卡住了。 有人能指出上面的 Python 代码有什么问题吗?

在我看来,您在索引中犯了错误。 我无法让它按预期工作,但下面的代码应该给你一个我要去哪里的指示(可能在某个地方一个一个地关闭):

import numpy as np

p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)

def get2(m, i, j):
    return m[i - 1, j - 1]


def set2(m, i, j, v):
    m[i - 1, j - 1] = v


def get1(m, i):
    return m[i - 1]


def set1(m, i, v):
    m[i - 1] = v


e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n + 1):
    for i in range(n - l + 2):
        j = i + l - 1
        set2(e, i, j, np.inf)
        set2(w, i, j, get2(w, i, j - 1) + get1(p, j) + get1(q, j))
        for r in range(i, j + 1):
            t = get2(e, i, r - 1) + get2(e, r + 1, j) + get2(w, i, j)
            if t < get2(e, i, j):
                set2(e, i, j, t)
                set2(root, i, j, r)

print(w)
print(e)

结果:

[[ 0.2   0.4   0.5   0.65  0.9   0.  ]
 [ 0.    0.2   0.3   0.45  0.7   0.  ]
 [ 0.    0.    0.1   0.25  0.5   0.  ]
 [ 0.    0.    0.    0.15  0.4   0.  ]
 [ 0.    0.    0.    0.    0.25  0.  ]
 [ 0.5   0.7   0.8   0.95  0.    0.3 ]]
[[ 0.2   0.6   0.8   1.2   1.95  0.  ]
 [ 0.    0.2   0.4   0.8   1.35  0.  ]
 [ 0.    0.    0.1   0.35  0.85  0.  ]
 [ 0.    0.    0.    0.15  0.55  0.  ]
 [ 0.    0.    0.    0.    0.25  0.  ]
 [ 0.7   1.2   1.5   2.    0.    0.3 ]]

最后,我使用了用自定义indexcolumns初始化的pandasSeriesDataFrame对象来强制数组具有与伪代码中相同的索引。 之后,伪代码几乎可以复制粘贴:

import numpy as np
import pandas as pd

P = [0.15, 0.10, 0.05, 0.10, 0.20]
Q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(P)

p = pd.Series(P, index=range(1, n+1))
q = pd.Series(Q)

e = pd.DataFrame(np.diag(Q), index=range(1, n+2))
w = pd.DataFrame(np.diag(Q), index=range(1, n+2))
root = pd.DataFrame(np.zeros((n, n)), index=range(1, n+1), columns=range(1, n+1))

for l in range(1, n+1):
    for i in range(1, n-l+2):
        j = i+l-1
        e.set_value(i, j, np.inf)
        w.set_value(i, j, w.get_value(i, j-1) + p[j] + q[j])
        for r in range(i, j+1):
            t = e.get_value(i, r-1) + e.get_value(r+1, j) + w.get_value(i, j)
            if t < e.get_value(i, j):
                e.set_value(i, j, t)
                root.set_value(i, j, r)

print(e)
print(w)
print(root)

这产生了预期的结果:

      0     1     2     3     4     5
1  0.05  0.45  0.90  1.25  1.75  2.75
2  0.00  0.10  0.40  0.70  1.20  2.00
3  0.00  0.00  0.05  0.25  0.60  1.30
4  0.00  0.00  0.00  0.05  0.30  0.90
5  0.00  0.00  0.00  0.00  0.05  0.50
6  0.00  0.00  0.00  0.00  0.00  0.10
      0    1     2     3     4     5
1  0.05  0.3  0.45  0.55  0.70  1.00
2  0.00  0.1  0.25  0.35  0.50  0.80
3  0.00  0.0  0.05  0.15  0.30  0.60
4  0.00  0.0  0.00  0.05  0.20  0.50
5  0.00  0.0  0.00  0.00  0.05  0.35
6  0.00  0.0  0.00  0.00  0.00  0.10
     1    2    3    4    5
1  1.0  1.0  2.0  2.0  2.0
2  0.0  2.0  2.0  2.0  4.0
3  0.0  0.0  3.0  4.0  5.0
4  0.0  0.0  0.0  4.0  5.0
5  0.0  0.0  0.0  0.0  5.0

不过,我仍然对 Numpy 数组的解决方案感兴趣,因为这对我来说似乎更优雅。

库尔特,谢谢你的帖子! 你是我发现的这个问题的唯一有效实现。 我花了很多时间与指数搏斗。 这是我使用 numpy 数组的实现。

import numpy as np
import math

def optimalBST(p,q,n):

    e = np.zeros((n+1)**2).reshape(n+1,n+1)
    w = np.zeros((n+1)**2).reshape(n+1,n+1)
    root = np.zeros((n+1)**2).reshape(n+1,n+1)

    # Initialization
    for i in range(n+1):
        e[i,i] = q[i]
        w[i,i] = q[i]
    for i in range(0,n):
        root[i,i] = i+1

    for l in range(1,n+1):
        for i in range(0, n-l+1):
            j = i+l
            min_ = math.inf
            w[i,j] = w[i,j-1] + p[j] + q[j]
            for r in range(i,j):
                t = e[i, r-1+1] + e[r+1,j] +  w[i,j]
                if t < min_:
                    min_ = t                
                    e[i, j] = t
                    root[i, j-1] = r+1

    root_pruned = np.delete(np.delete(root, n, 1), n, 0)        # Trim last col & row.

    print("------ e -------")
    print(e)
    print("------ w -------")
    print(w)
    print("----- root -----")
    print(root_pruned)

def main():

    p = [0,.15,.1,.05,.1,.2]
    q = [.05,.1,.05,.05,.05,.1]
    n = len(p)-1

    optimalBST(p,q,n)

if __name__ == '__main__':
    main()

输出:

------ e -------
[[0.05 0.45 0.9  1.25 1.75 2.75]
 [0.   0.1  0.4  0.7  1.2  2.  ]
 [0.   0.   0.05 0.25 0.6  1.3 ]
 [0.   0.   0.   0.05 0.3  0.9 ]
 [0.   0.   0.   0.   0.05 0.5 ]
 [0.   0.   0.   0.   0.   0.1 ]]
------ w -------
[[0.05 0.3  0.45 0.55 0.7  1.  ]
 [0.   0.1  0.25 0.35 0.5  0.8 ]
 [0.   0.   0.05 0.15 0.3  0.6 ]
 [0.   0.   0.   0.05 0.2  0.5 ]
 [0.   0.   0.   0.   0.05 0.35]
 [0.   0.   0.   0.   0.   0.1 ]]
----- root -----
[[1. 1. 2. 2. 2.]
 [0. 2. 2. 2. 4.]
 [0. 0. 3. 4. 5.]
 [0. 0. 0. 4. 5.]
 [0. 0. 0. 0. 5.]]

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM