[英]Karatsuba algorithm too much recursion
我正在嘗試在 c++ 中實現 Karatsuba 乘法算法,但現在我只是想讓它在 python 中工作。
這是我的代碼:
def mult(x, y, b, m):
if max(x, y) < b:
return x * y
bm = pow(b, m)
x0 = x / bm
x1 = x % bm
y0 = y / bm
y1 = y % bm
z2 = mult(x1, y1, b, m)
z0 = mult(x0, y0, b, m)
z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0
return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0
我不明白的是:應該如何創建z2
、 z1
和z0
? 遞歸使用mult
function 是否正確? 如果是這樣,我在某個地方搞砸了,因為遞歸沒有停止。
有人可以指出錯誤在哪里嗎?
注意:下面的回復直接解決了 OP 關於過度遞歸的問題,但它並沒有嘗試提供正確的 Karatsuba 算法。 在這方面,其他答復提供的信息要多得多。
試試這個版本:
def mult(x, y, b, m):
bm = pow(b, m)
if min(x, y) <= bm:
return x * y
# NOTE the following 4 lines
x0 = x % bm
x1 = x / bm
y0 = y % bm
y1 = y / bm
z0 = mult(x0, y0, b, m)
z2 = mult(x1, y1, b, m)
z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0
retval = mult(mult(z2, bm, b, m) + z1, bm, b, m) + z0
assert retval == x * y, "%d * %d == %d != %d" % (x, y, x * y, retval)
return retval
您的版本最嚴重的問題是您對 x0 和 x1 以及 y0 和 y1 的計算被翻轉了。 此外,如果x1
和y1
為 0,則算法的推導不成立,因為在這種情況下,分解步驟變得無效。 因此,您必須通過確保 x 和 y 都大於 b**m 來避免這種可能性。
編輯:修復了代碼中的錯字; 補充說明
編輯2:
為了更清楚,直接評論您的原始版本:
def mult(x, y, b, m):
# The termination condition will never be true when the recursive
# call is either
# mult(z2, bm ** 2, b, m)
# or mult(z1, bm, b, m)
#
# Since every recursive call leads to one of the above, you have an
# infinite recursion condition.
if max(x, y) < b:
return x * y
bm = pow(b, m)
# Even without the recursion problem, the next four lines are wrong
x0 = x / bm # RHS should be x % bm
x1 = x % bm # RHS should be x / bm
y0 = y / bm # RHS should be y % bm
y1 = y % bm # RHS should be y / bm
z2 = mult(x1, y1, b, m)
z0 = mult(x0, y0, b, m)
z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0
return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0
通常大數存儲為整數的 arrays。 每個 integer 代表一位。 這種方法允許通過數組的簡單左移將任意數字乘以基數的冪。
這是我的基於列表的實現(可能包含錯誤):
def normalize(l,b):
over = 0
for i,x in enumerate(l):
over,l[i] = divmod(x+over,b)
if over: l.append(over)
return l
def sum_lists(x,y,b):
l = min(len(x),len(y))
res = map(operator.add,x[:l],y[:l])
if len(x) > l: res.extend(x[l:])
else: res.extend(y[l:])
return normalize(res,b)
def sub_lists(x,y,b):
res = map(operator.sub,x[:len(y)],y)
res.extend(x[len(y):])
return normalize(res,b)
def lshift(x,n):
if len(x) > 1 or len(x) == 1 and x[0] != 0:
return [0 for i in range(n)] + x
else: return x
def mult_lists(x,y,b):
if min(len(x),len(y)) == 0: return [0]
m = max(len(x),len(y))
if (m == 1): return normalize([x[0]*y[0]],b)
else: m >>= 1
x0,x1 = x[:m],x[m:]
y0,y1 = y[:m],y[m:]
z0 = mult_lists(x0,y0,b)
z1 = mult_lists(x1,y1,b)
z2 = mult_lists(sum_lists(x0,x1,b),sum_lists(y0,y1,b),b)
t1 = lshift(sub_lists(z2,sum_lists(z1,z0,b),b),m)
t2 = lshift(z1,m*2)
return sum_lists(sum_lists(z0,t1,b),t2,b)
sum_lists
和sub_lists
返回非標准化結果 - 單個數字可以大於基值。 normalize
function 解決了這個問題。
所有函數都希望以相反的順序獲取數字列表。 例如,以 10 為底的 12 應寫為 [2,1]。 讓我們取一個 9987654321 的正方形。
» a = [1,2,3,4,5,6,7,8,9]
» res = mult_lists(a,a,10)
» res.reverse()
» res
[9, 7, 5, 4, 6, 1, 0, 5, 7, 7, 8, 9, 9, 7, 1, 0, 4, 1]
Karatsuba 乘法的目標是通過進行 3 次遞歸調用而不是 4 次來改進分治乘法算法。 因此,腳本中唯一應該包含對乘法的遞歸調用的行是那些分配z0
、 z1
和z2
的行。 其他任何事情都會給您帶來更糟糕的復雜性。 當您還沒有定義乘法(以及更進一步的求冪)時,您也不能使用pow
來計算b m 。
為此,該算法關鍵地利用了它使用位置符號系統這一事實。 如果您有一個以b為底的數字的表示x ,則只需將該表示的數字向左移動m
次即可獲得x*b m 。 對於任何位置符號系統,這種移位操作基本上是“免費的”。 這也意味着如果你想實現它,你必須重現這個位置符號和“自由”移位。 要么您選擇以b=2為基數進行計算並使用 python 的位運算符(或給定十進制、十六進制、...基數的位運算符,如果您的測試平台有它們),或者您決定為教育目的實現一些有效的東西對於任意b ,您可以使用類似字符串、arrays 或 lists 之類的東西來重現此位置算術。
您已經有了一個包含列表的解決方案。 我喜歡使用 python 中的字符串,因為int(s, base)
會給你 integer 對應於字符串s
被視為 base base
中的數字表示:它使測試變得容易。 我在這里發布了一個受到大量評論的基於字符串的實現作為要點,包括字符串到數字和數字到字符串的原語,以便進行很好的衡量。
您可以通過為mult
提供具有基數及其(相等)長度的填充字符串來測試它: arguments
In [169]: mult("987654321","987654321",10,9)
Out[169]: '966551847789971041'
如果您不想計算填充或計算字符串長度,可以使用填充 function 為您完成:
In [170]: padding("987654321","2")
Out[170]: ('987654321', '000000002', 9)
當然它適用於b>10
:
In [171]: mult('987654321', '000000002', 16, 9)
Out[171]: '130eca8642'
(檢查wolfram alpha )
我相信該技術背后的想法是使用遞歸算法計算 z i項,但結果並沒有以這種方式統一在一起。 由於您想要的最終結果是
z0 B^2m + z1 B^m + z2
假設您選擇了合適的 B 值(例如 2),您可以在不進行任何乘法的情況下計算 B^m。 例如,當使用 B = 2 時,您可以使用位移而不是乘法來計算 B^m。 這意味着可以在不進行任何乘法的情況下完成最后一步。
還有一件事——我注意到你為整個算法選擇了一個固定的 m 值。 通常,您可以通過讓 m 始終是一個值來實現此算法,使得當 x 和 y 以 B 為基數寫入時,B^m 是 x 和 y 中位數的一半。如果您使用的是 2 的冪,則可以這樣做通過選擇 m = ceil((log x) / 2)。
希望這可以幫助!
在 Python 2.7 中:將此文件另存為 Karatsuba.py
def karatsuba(x,y):
"""Karatsuba multiplication algorithm.
Return the product of two numbers in an efficient manner
@author Shashank
date: 23-09-2018
Parameters
----------
x : int
First Number
y : int
Second Number
Returns
-------
prod : int
The product of two numbers
Examples
--------
>>> import Karatsuba.karatsuba
>>> a = 1234567899876543211234567899876543211234567899876543211234567890
>>> b = 9876543211234567899876543211234567899876543211234567899876543210
>>> Karatsuba.karatsuba(a,b)
12193263210333790590595945731931108068998628253528425547401310676055479323014784354458161844612101832860844366209419311263526900
"""
if len(str(x)) == 1 or len(str(y)) == 1:
return x*y
else:
n = max(len(str(x)), len(str(y)))
m = n/2
a = x/10**m
b = x%10**m
c = y/10**m
d = y%10**m
ac = karatsuba(a,c) #step 1
bd = karatsuba(b,d) #step 2
ad_plus_bc = karatsuba(a+b, c+d) - ac - bd #step 3
prod = ac*10**(2*m) + bd + ad_plus_bc*10**m #step 4
return prod
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.