[英]Generating an optimal binary search tree (Cormen)
我正在閱讀 Cormen 等人的算法簡介(第 3 版)( PDF ),關於最佳二叉搜索樹的第 15.4 節,但是在 Python 中實現optimal_bst
函數的偽代碼時遇到了一些問題。
這是我嘗試將最佳 BST 應用於的示例:
讓我們將e[i,j]
定義為搜索包含從i
到j
標記的鍵的最優二叉搜索樹的預期成本。 最終,我們希望計算e[1, n]
,其中n
是鍵的數量(在本例中為 5)。 最終的遞歸公式是:
這應該由以下偽代碼實現:
請注意,偽代碼可互換地使用基於 1 和基於 0 的索引,而 Python 僅使用后者。 因此,我在實現偽代碼時遇到了麻煩。 這是我到目前為止所擁有的:
import numpy as np
p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)
e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n+1):
for i in range(n-l+1):
j = i + l
e[i, j] = np.inf
w[i, j] = w[i, j-1] + p[j-1] + q[j]
for r in range(i, j+1):
t = e[i-1, r-1] + e[r, j] + w[i-1, j]
if t < e[i-1, j]:
e[i-1, j] = t
root[i-1, j] = r
print(w)
print(e)
但是,如果我運行這個,權重w
會得到正確計算,但預期的搜索值e
仍然“卡住”在它們的初始化值上:
[[ 0.05 0.3 0.45 0.55 0.7 1. ]
[ 0. 0.1 0.25 0.35 0.5 0.8 ]
[ 0. 0. 0.05 0.15 0.3 0.6 ]
[ 0. 0. 0. 0.05 0.2 0.5 ]
[ 0. 0. 0. 0. 0.05 0.35]
[ 0. 0. 0. 0. 0. 0.1 ]]
[[ 0.05 inf inf inf inf inf]
[ 0. 0.1 inf inf inf inf]
[ 0. 0. 0.05 inf inf inf]
[ 0. 0. 0. 0.05 inf inf]
[ 0. 0. 0. 0. 0.05 inf]
[ 0. 0. 0. 0. 0. 0.1 ]]
我期望的是e
、 w
和root
如下:
我已經調試了幾個小時了,但仍然卡住了。 有人能指出上面的 Python 代碼有什么問題嗎?
在我看來,您在索引中犯了錯誤。 我無法讓它按預期工作,但下面的代碼應該給你一個我要去哪里的指示(可能在某個地方一個一個地關閉):
import numpy as np
p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)
def get2(m, i, j):
return m[i - 1, j - 1]
def set2(m, i, j, v):
m[i - 1, j - 1] = v
def get1(m, i):
return m[i - 1]
def set1(m, i, v):
m[i - 1] = v
e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n + 1):
for i in range(n - l + 2):
j = i + l - 1
set2(e, i, j, np.inf)
set2(w, i, j, get2(w, i, j - 1) + get1(p, j) + get1(q, j))
for r in range(i, j + 1):
t = get2(e, i, r - 1) + get2(e, r + 1, j) + get2(w, i, j)
if t < get2(e, i, j):
set2(e, i, j, t)
set2(root, i, j, r)
print(w)
print(e)
結果:
[[ 0.2 0.4 0.5 0.65 0.9 0. ]
[ 0. 0.2 0.3 0.45 0.7 0. ]
[ 0. 0. 0.1 0.25 0.5 0. ]
[ 0. 0. 0. 0.15 0.4 0. ]
[ 0. 0. 0. 0. 0.25 0. ]
[ 0.5 0.7 0.8 0.95 0. 0.3 ]]
[[ 0.2 0.6 0.8 1.2 1.95 0. ]
[ 0. 0.2 0.4 0.8 1.35 0. ]
[ 0. 0. 0.1 0.35 0.85 0. ]
[ 0. 0. 0. 0.15 0.55 0. ]
[ 0. 0. 0. 0. 0.25 0. ]
[ 0.7 1.2 1.5 2. 0. 0.3 ]]
最后,我使用了用自定義index
和columns
初始化的pandas的Series
和DataFrame
對象來強制數組具有與偽代碼中相同的索引。 之后,偽代碼幾乎可以復制粘貼:
import numpy as np
import pandas as pd
P = [0.15, 0.10, 0.05, 0.10, 0.20]
Q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(P)
p = pd.Series(P, index=range(1, n+1))
q = pd.Series(Q)
e = pd.DataFrame(np.diag(Q), index=range(1, n+2))
w = pd.DataFrame(np.diag(Q), index=range(1, n+2))
root = pd.DataFrame(np.zeros((n, n)), index=range(1, n+1), columns=range(1, n+1))
for l in range(1, n+1):
for i in range(1, n-l+2):
j = i+l-1
e.set_value(i, j, np.inf)
w.set_value(i, j, w.get_value(i, j-1) + p[j] + q[j])
for r in range(i, j+1):
t = e.get_value(i, r-1) + e.get_value(r+1, j) + w.get_value(i, j)
if t < e.get_value(i, j):
e.set_value(i, j, t)
root.set_value(i, j, r)
print(e)
print(w)
print(root)
這產生了預期的結果:
0 1 2 3 4 5
1 0.05 0.45 0.90 1.25 1.75 2.75
2 0.00 0.10 0.40 0.70 1.20 2.00
3 0.00 0.00 0.05 0.25 0.60 1.30
4 0.00 0.00 0.00 0.05 0.30 0.90
5 0.00 0.00 0.00 0.00 0.05 0.50
6 0.00 0.00 0.00 0.00 0.00 0.10
0 1 2 3 4 5
1 0.05 0.3 0.45 0.55 0.70 1.00
2 0.00 0.1 0.25 0.35 0.50 0.80
3 0.00 0.0 0.05 0.15 0.30 0.60
4 0.00 0.0 0.00 0.05 0.20 0.50
5 0.00 0.0 0.00 0.00 0.05 0.35
6 0.00 0.0 0.00 0.00 0.00 0.10
1 2 3 4 5
1 1.0 1.0 2.0 2.0 2.0
2 0.0 2.0 2.0 2.0 4.0
3 0.0 0.0 3.0 4.0 5.0
4 0.0 0.0 0.0 4.0 5.0
5 0.0 0.0 0.0 0.0 5.0
不過,我仍然對 Numpy 數組的解決方案感興趣,因為這對我來說似乎更優雅。
庫爾特,謝謝你的帖子! 你是我發現的這個問題的唯一有效實現。 我花了很多時間與指數搏斗。 這是我使用 numpy 數組的實現。
import numpy as np
import math
def optimalBST(p,q,n):
e = np.zeros((n+1)**2).reshape(n+1,n+1)
w = np.zeros((n+1)**2).reshape(n+1,n+1)
root = np.zeros((n+1)**2).reshape(n+1,n+1)
# Initialization
for i in range(n+1):
e[i,i] = q[i]
w[i,i] = q[i]
for i in range(0,n):
root[i,i] = i+1
for l in range(1,n+1):
for i in range(0, n-l+1):
j = i+l
min_ = math.inf
w[i,j] = w[i,j-1] + p[j] + q[j]
for r in range(i,j):
t = e[i, r-1+1] + e[r+1,j] + w[i,j]
if t < min_:
min_ = t
e[i, j] = t
root[i, j-1] = r+1
root_pruned = np.delete(np.delete(root, n, 1), n, 0) # Trim last col & row.
print("------ e -------")
print(e)
print("------ w -------")
print(w)
print("----- root -----")
print(root_pruned)
def main():
p = [0,.15,.1,.05,.1,.2]
q = [.05,.1,.05,.05,.05,.1]
n = len(p)-1
optimalBST(p,q,n)
if __name__ == '__main__':
main()
輸出:
------ e -------
[[0.05 0.45 0.9 1.25 1.75 2.75]
[0. 0.1 0.4 0.7 1.2 2. ]
[0. 0. 0.05 0.25 0.6 1.3 ]
[0. 0. 0. 0.05 0.3 0.9 ]
[0. 0. 0. 0. 0.05 0.5 ]
[0. 0. 0. 0. 0. 0.1 ]]
------ w -------
[[0.05 0.3 0.45 0.55 0.7 1. ]
[0. 0.1 0.25 0.35 0.5 0.8 ]
[0. 0. 0.05 0.15 0.3 0.6 ]
[0. 0. 0. 0.05 0.2 0.5 ]
[0. 0. 0. 0. 0.05 0.35]
[0. 0. 0. 0. 0. 0.1 ]]
----- root -----
[[1. 1. 2. 2. 2.]
[0. 2. 2. 2. 4.]
[0. 0. 3. 4. 5.]
[0. 0. 0. 4. 5.]
[0. 0. 0. 0. 5.]]
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.