[英]Generating an optimal binary search tree (Cormen)
我正在阅读 Cormen 等人的算法简介(第 3 版)( PDF ),关于最佳二叉搜索树的第 15.4 节,但是在 Python 中实现optimal_bst
函数的伪代码时遇到了一些问题。
这是我尝试将最佳 BST 应用于的示例:
让我们将e[i,j]
定义为搜索包含从i
到j
标记的键的最优二叉搜索树的预期成本。 最终,我们希望计算e[1, n]
,其中n
是键的数量(在本例中为 5)。 最终的递归公式是:
这应该由以下伪代码实现:
请注意,伪代码可互换地使用基于 1 和基于 0 的索引,而 Python 仅使用后者。 因此,我在实现伪代码时遇到了麻烦。 这是我到目前为止所拥有的:
import numpy as np
p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)
e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n+1):
for i in range(n-l+1):
j = i + l
e[i, j] = np.inf
w[i, j] = w[i, j-1] + p[j-1] + q[j]
for r in range(i, j+1):
t = e[i-1, r-1] + e[r, j] + w[i-1, j]
if t < e[i-1, j]:
e[i-1, j] = t
root[i-1, j] = r
print(w)
print(e)
但是,如果我运行这个,权重w
会得到正确计算,但预期的搜索值e
仍然“卡住”在它们的初始化值上:
[[ 0.05 0.3 0.45 0.55 0.7 1. ]
[ 0. 0.1 0.25 0.35 0.5 0.8 ]
[ 0. 0. 0.05 0.15 0.3 0.6 ]
[ 0. 0. 0. 0.05 0.2 0.5 ]
[ 0. 0. 0. 0. 0.05 0.35]
[ 0. 0. 0. 0. 0. 0.1 ]]
[[ 0.05 inf inf inf inf inf]
[ 0. 0.1 inf inf inf inf]
[ 0. 0. 0.05 inf inf inf]
[ 0. 0. 0. 0.05 inf inf]
[ 0. 0. 0. 0. 0.05 inf]
[ 0. 0. 0. 0. 0. 0.1 ]]
我期望的是e
、 w
和root
如下:
我已经调试了几个小时了,但仍然卡住了。 有人能指出上面的 Python 代码有什么问题吗?
在我看来,您在索引中犯了错误。 我无法让它按预期工作,但下面的代码应该给你一个我要去哪里的指示(可能在某个地方一个一个地关闭):
import numpy as np
p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)
def get2(m, i, j):
return m[i - 1, j - 1]
def set2(m, i, j, v):
m[i - 1, j - 1] = v
def get1(m, i):
return m[i - 1]
def set1(m, i, v):
m[i - 1] = v
e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n + 1):
for i in range(n - l + 2):
j = i + l - 1
set2(e, i, j, np.inf)
set2(w, i, j, get2(w, i, j - 1) + get1(p, j) + get1(q, j))
for r in range(i, j + 1):
t = get2(e, i, r - 1) + get2(e, r + 1, j) + get2(w, i, j)
if t < get2(e, i, j):
set2(e, i, j, t)
set2(root, i, j, r)
print(w)
print(e)
结果:
[[ 0.2 0.4 0.5 0.65 0.9 0. ]
[ 0. 0.2 0.3 0.45 0.7 0. ]
[ 0. 0. 0.1 0.25 0.5 0. ]
[ 0. 0. 0. 0.15 0.4 0. ]
[ 0. 0. 0. 0. 0.25 0. ]
[ 0.5 0.7 0.8 0.95 0. 0.3 ]]
[[ 0.2 0.6 0.8 1.2 1.95 0. ]
[ 0. 0.2 0.4 0.8 1.35 0. ]
[ 0. 0. 0.1 0.35 0.85 0. ]
[ 0. 0. 0. 0.15 0.55 0. ]
[ 0. 0. 0. 0. 0.25 0. ]
[ 0.7 1.2 1.5 2. 0. 0.3 ]]
最后,我使用了用自定义index
和columns
初始化的pandas的Series
和DataFrame
对象来强制数组具有与伪代码中相同的索引。 之后,伪代码几乎可以复制粘贴:
import numpy as np
import pandas as pd
P = [0.15, 0.10, 0.05, 0.10, 0.20]
Q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(P)
p = pd.Series(P, index=range(1, n+1))
q = pd.Series(Q)
e = pd.DataFrame(np.diag(Q), index=range(1, n+2))
w = pd.DataFrame(np.diag(Q), index=range(1, n+2))
root = pd.DataFrame(np.zeros((n, n)), index=range(1, n+1), columns=range(1, n+1))
for l in range(1, n+1):
for i in range(1, n-l+2):
j = i+l-1
e.set_value(i, j, np.inf)
w.set_value(i, j, w.get_value(i, j-1) + p[j] + q[j])
for r in range(i, j+1):
t = e.get_value(i, r-1) + e.get_value(r+1, j) + w.get_value(i, j)
if t < e.get_value(i, j):
e.set_value(i, j, t)
root.set_value(i, j, r)
print(e)
print(w)
print(root)
这产生了预期的结果:
0 1 2 3 4 5
1 0.05 0.45 0.90 1.25 1.75 2.75
2 0.00 0.10 0.40 0.70 1.20 2.00
3 0.00 0.00 0.05 0.25 0.60 1.30
4 0.00 0.00 0.00 0.05 0.30 0.90
5 0.00 0.00 0.00 0.00 0.05 0.50
6 0.00 0.00 0.00 0.00 0.00 0.10
0 1 2 3 4 5
1 0.05 0.3 0.45 0.55 0.70 1.00
2 0.00 0.1 0.25 0.35 0.50 0.80
3 0.00 0.0 0.05 0.15 0.30 0.60
4 0.00 0.0 0.00 0.05 0.20 0.50
5 0.00 0.0 0.00 0.00 0.05 0.35
6 0.00 0.0 0.00 0.00 0.00 0.10
1 2 3 4 5
1 1.0 1.0 2.0 2.0 2.0
2 0.0 2.0 2.0 2.0 4.0
3 0.0 0.0 3.0 4.0 5.0
4 0.0 0.0 0.0 4.0 5.0
5 0.0 0.0 0.0 0.0 5.0
不过,我仍然对 Numpy 数组的解决方案感兴趣,因为这对我来说似乎更优雅。
库尔特,谢谢你的帖子! 你是我发现的这个问题的唯一有效实现。 我花了很多时间与指数搏斗。 这是我使用 numpy 数组的实现。
import numpy as np
import math
def optimalBST(p,q,n):
e = np.zeros((n+1)**2).reshape(n+1,n+1)
w = np.zeros((n+1)**2).reshape(n+1,n+1)
root = np.zeros((n+1)**2).reshape(n+1,n+1)
# Initialization
for i in range(n+1):
e[i,i] = q[i]
w[i,i] = q[i]
for i in range(0,n):
root[i,i] = i+1
for l in range(1,n+1):
for i in range(0, n-l+1):
j = i+l
min_ = math.inf
w[i,j] = w[i,j-1] + p[j] + q[j]
for r in range(i,j):
t = e[i, r-1+1] + e[r+1,j] + w[i,j]
if t < min_:
min_ = t
e[i, j] = t
root[i, j-1] = r+1
root_pruned = np.delete(np.delete(root, n, 1), n, 0) # Trim last col & row.
print("------ e -------")
print(e)
print("------ w -------")
print(w)
print("----- root -----")
print(root_pruned)
def main():
p = [0,.15,.1,.05,.1,.2]
q = [.05,.1,.05,.05,.05,.1]
n = len(p)-1
optimalBST(p,q,n)
if __name__ == '__main__':
main()
输出:
------ e -------
[[0.05 0.45 0.9 1.25 1.75 2.75]
[0. 0.1 0.4 0.7 1.2 2. ]
[0. 0. 0.05 0.25 0.6 1.3 ]
[0. 0. 0. 0.05 0.3 0.9 ]
[0. 0. 0. 0. 0.05 0.5 ]
[0. 0. 0. 0. 0. 0.1 ]]
------ w -------
[[0.05 0.3 0.45 0.55 0.7 1. ]
[0. 0.1 0.25 0.35 0.5 0.8 ]
[0. 0. 0.05 0.15 0.3 0.6 ]
[0. 0. 0. 0.05 0.2 0.5 ]
[0. 0. 0. 0. 0.05 0.35]
[0. 0. 0. 0. 0. 0.1 ]]
----- root -----
[[1. 1. 2. 2. 2.]
[0. 2. 2. 2. 4.]
[0. 0. 3. 4. 5.]
[0. 0. 0. 4. 5.]
[0. 0. 0. 0. 5.]]
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.