[英]Generating an optimal binary search tree (Cormen)
I'm reading Cormen et al., Introduction to Algorithms (3rd ed.) ( PDF ), section 15.4 on optimal binary search trees, but am having some trouble implementing the pseudocode for the optimal_bst
function in Python.我正在阅读 Cormen 等人的算法简介(第 3 版)( PDF ),关于最佳二叉搜索树的第 15.4 节,但是在 Python 中实现optimal_bst
函数的伪代码时遇到了一些问题。
Here is the example I'm trying to apply the optimal BST to:这是我尝试将最佳 BST 应用于的示例:
Let us define e[i,j]
as the expected cost of searching an optimal binary search tree containing the keys labeled from i
to j
.让我们将e[i,j]
定义为搜索包含从i
到j
标记的键的最优二叉搜索树的预期成本。 Ultimately, we wish to compute e[1, n]
, where n
is the number of keys (5 in this example).最终,我们希望计算e[1, n]
,其中n
是键的数量(在本例中为 5)。 The final recursive formulation is:最终的递归公式是:
which should be implemented by the following pseudocode:这应该由以下伪代码实现:
Notice that the pseudocode interchangeably uses 1- and 0-based indexing, whereas Python uses only the latter.请注意,伪代码可互换地使用基于 1 和基于 0 的索引,而 Python 仅使用后者。 As a consequence I'm having trouble implementing the pseudocode.因此,我在实现伪代码时遇到了麻烦。 Here is what I have so far:这是我到目前为止所拥有的:
import numpy as np
p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)
e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n+1):
for i in range(n-l+1):
j = i + l
e[i, j] = np.inf
w[i, j] = w[i, j-1] + p[j-1] + q[j]
for r in range(i, j+1):
t = e[i-1, r-1] + e[r, j] + w[i-1, j]
if t < e[i-1, j]:
e[i-1, j] = t
root[i-1, j] = r
print(w)
print(e)
However, if I run this the weights w
get computed correctly, but the expected search values e
remain 'stuck' at their initialized values:但是,如果我运行这个,权重w
会得到正确计算,但预期的搜索值e
仍然“卡住”在它们的初始化值上:
[[ 0.05 0.3 0.45 0.55 0.7 1. ]
[ 0. 0.1 0.25 0.35 0.5 0.8 ]
[ 0. 0. 0.05 0.15 0.3 0.6 ]
[ 0. 0. 0. 0.05 0.2 0.5 ]
[ 0. 0. 0. 0. 0.05 0.35]
[ 0. 0. 0. 0. 0. 0.1 ]]
[[ 0.05 inf inf inf inf inf]
[ 0. 0.1 inf inf inf inf]
[ 0. 0. 0.05 inf inf inf]
[ 0. 0. 0. 0.05 inf inf]
[ 0. 0. 0. 0. 0.05 inf]
[ 0. 0. 0. 0. 0. 0.1 ]]
What I expect is that e
, w
, and root
be as follows:我期望的是e
、 w
和root
如下:
I've been debugging this for a couple of hours by now and am still stuck.我已经调试了几个小时了,但仍然卡住了。 Can someone point out what is wrong with the Python code above?有人能指出上面的 Python 代码有什么问题吗?
It appears to me that you made a mistake in the indices.在我看来,您在索引中犯了错误。 I couldn't make it work as expected but the following code should give you an indication where I was heading at (there is probably an off by one somewhere):我无法让它按预期工作,但下面的代码应该给你一个我要去哪里的指示(可能在某个地方一个一个地关闭):
import numpy as np
p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)
def get2(m, i, j):
return m[i - 1, j - 1]
def set2(m, i, j, v):
m[i - 1, j - 1] = v
def get1(m, i):
return m[i - 1]
def set1(m, i, v):
m[i - 1] = v
e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n + 1):
for i in range(n - l + 2):
j = i + l - 1
set2(e, i, j, np.inf)
set2(w, i, j, get2(w, i, j - 1) + get1(p, j) + get1(q, j))
for r in range(i, j + 1):
t = get2(e, i, r - 1) + get2(e, r + 1, j) + get2(w, i, j)
if t < get2(e, i, j):
set2(e, i, j, t)
set2(root, i, j, r)
print(w)
print(e)
The result:结果:
[[ 0.2 0.4 0.5 0.65 0.9 0. ]
[ 0. 0.2 0.3 0.45 0.7 0. ]
[ 0. 0. 0.1 0.25 0.5 0. ]
[ 0. 0. 0. 0.15 0.4 0. ]
[ 0. 0. 0. 0. 0.25 0. ]
[ 0.5 0.7 0.8 0.95 0. 0.3 ]]
[[ 0.2 0.6 0.8 1.2 1.95 0. ]
[ 0. 0.2 0.4 0.8 1.35 0. ]
[ 0. 0. 0.1 0.35 0.85 0. ]
[ 0. 0. 0. 0.15 0.55 0. ]
[ 0. 0. 0. 0. 0.25 0. ]
[ 0.7 1.2 1.5 2. 0. 0.3 ]]
In the end I used pandas ' Series
and DataFrame
objects initialized with custom index
and columns
to coerce the arrays to have the same indexing as in the pseudocode.最后,我使用了用自定义index
和columns
初始化的pandas的Series
和DataFrame
对象来强制数组具有与伪代码中相同的索引。 After that, the pseudocode can be almost copy-pasted:之后,伪代码几乎可以复制粘贴:
import numpy as np
import pandas as pd
P = [0.15, 0.10, 0.05, 0.10, 0.20]
Q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(P)
p = pd.Series(P, index=range(1, n+1))
q = pd.Series(Q)
e = pd.DataFrame(np.diag(Q), index=range(1, n+2))
w = pd.DataFrame(np.diag(Q), index=range(1, n+2))
root = pd.DataFrame(np.zeros((n, n)), index=range(1, n+1), columns=range(1, n+1))
for l in range(1, n+1):
for i in range(1, n-l+2):
j = i+l-1
e.set_value(i, j, np.inf)
w.set_value(i, j, w.get_value(i, j-1) + p[j] + q[j])
for r in range(i, j+1):
t = e.get_value(i, r-1) + e.get_value(r+1, j) + w.get_value(i, j)
if t < e.get_value(i, j):
e.set_value(i, j, t)
root.set_value(i, j, r)
print(e)
print(w)
print(root)
which yields the expected results:这产生了预期的结果:
0 1 2 3 4 5
1 0.05 0.45 0.90 1.25 1.75 2.75
2 0.00 0.10 0.40 0.70 1.20 2.00
3 0.00 0.00 0.05 0.25 0.60 1.30
4 0.00 0.00 0.00 0.05 0.30 0.90
5 0.00 0.00 0.00 0.00 0.05 0.50
6 0.00 0.00 0.00 0.00 0.00 0.10
0 1 2 3 4 5
1 0.05 0.3 0.45 0.55 0.70 1.00
2 0.00 0.1 0.25 0.35 0.50 0.80
3 0.00 0.0 0.05 0.15 0.30 0.60
4 0.00 0.0 0.00 0.05 0.20 0.50
5 0.00 0.0 0.00 0.00 0.05 0.35
6 0.00 0.0 0.00 0.00 0.00 0.10
1 2 3 4 5
1 1.0 1.0 2.0 2.0 2.0
2 0.0 2.0 2.0 2.0 4.0
3 0.0 0.0 3.0 4.0 5.0
4 0.0 0.0 0.0 4.0 5.0
5 0.0 0.0 0.0 0.0 5.0
I would still be interested in a solution with Numpy arrays, though, as this seems more elegant to me.不过,我仍然对 Numpy 数组的解决方案感兴趣,因为这对我来说似乎更优雅。
Kurt, Thanks for your post!库尔特,谢谢你的帖子! Yours is the only working implementation of this problem I found.你是我发现的这个问题的唯一有效实现。 I spent mucho tiempo wrestling with the indices.我花了很多时间与指数搏斗。 Here is my implementation with numpy arrays.这是我使用 numpy 数组的实现。
import numpy as np
import math
def optimalBST(p,q,n):
e = np.zeros((n+1)**2).reshape(n+1,n+1)
w = np.zeros((n+1)**2).reshape(n+1,n+1)
root = np.zeros((n+1)**2).reshape(n+1,n+1)
# Initialization
for i in range(n+1):
e[i,i] = q[i]
w[i,i] = q[i]
for i in range(0,n):
root[i,i] = i+1
for l in range(1,n+1):
for i in range(0, n-l+1):
j = i+l
min_ = math.inf
w[i,j] = w[i,j-1] + p[j] + q[j]
for r in range(i,j):
t = e[i, r-1+1] + e[r+1,j] + w[i,j]
if t < min_:
min_ = t
e[i, j] = t
root[i, j-1] = r+1
root_pruned = np.delete(np.delete(root, n, 1), n, 0) # Trim last col & row.
print("------ e -------")
print(e)
print("------ w -------")
print(w)
print("----- root -----")
print(root_pruned)
def main():
p = [0,.15,.1,.05,.1,.2]
q = [.05,.1,.05,.05,.05,.1]
n = len(p)-1
optimalBST(p,q,n)
if __name__ == '__main__':
main()
Output:输出:
------ e -------
[[0.05 0.45 0.9 1.25 1.75 2.75]
[0. 0.1 0.4 0.7 1.2 2. ]
[0. 0. 0.05 0.25 0.6 1.3 ]
[0. 0. 0. 0.05 0.3 0.9 ]
[0. 0. 0. 0. 0.05 0.5 ]
[0. 0. 0. 0. 0. 0.1 ]]
------ w -------
[[0.05 0.3 0.45 0.55 0.7 1. ]
[0. 0.1 0.25 0.35 0.5 0.8 ]
[0. 0. 0.05 0.15 0.3 0.6 ]
[0. 0. 0. 0.05 0.2 0.5 ]
[0. 0. 0. 0. 0.05 0.35]
[0. 0. 0. 0. 0. 0.1 ]]
----- root -----
[[1. 1. 2. 2. 2.]
[0. 2. 2. 2. 4.]
[0. 0. 3. 4. 5.]
[0. 0. 0. 4. 5.]
[0. 0. 0. 0. 5.]]
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.