[英]How to better implement recurrence relations in code?
我已經為這個問題苦苦掙扎了兩天了,我還是Python和更多數學密集型編碼的新手,所以請多多指教,只要指出正確的方向即可:)
所以問題是這樣的:
您有一張有效期為N天的電影通行證。 您可以使用任意方式使用它, 連續3天或更長時間除外。
因此,基本上,您可以在特定日期使用通行證,也可以選擇不使用通行證,這意味着將2個提高至N個可能性。 領取通行證的有效方法然后提高為N-2個無效案例
您必須找到有效案例數%(10 ^ 9 + 7)
我發現無效案例的遞歸關系看起來像
invalidCases(at_N)= 2 ^(n-4)+ 2 * invalidCases(at_N-1)-invalidCases(at_n-4)
所以我的第一個沖動就是簡單地使用遞歸:
def invalidCases(n):
if(n<3):
return 0;
elif(n==3):
return 1;
else:
return 2**(n-4)+ 2*invalidCases(n-1)- invalidCases(n-4)
效率很低,但我的公式似乎正確。 我的下一次嘗試是嘗試進行記憶,但是在N = 1006時我一直遇到錯誤。 因此,我更改了遞歸限制。
我目前的嘗試(帶有記憶和遞歸限制)
import sys
sys.setrecursionlimit(10**6)
T=int(input());
#2**(n-4) + 2*ans(n-1)-ans(n-4)
memo={0:0,1:0,2:0,3:1,4:3,5:8,6:20,7:47} #
def ans(n):
#complete this function
if n not in memo:
memo[n]=(2**(n-4) + 2*ans(n-1)-ans(n-4));
return memo[n];
modulo = 10**9 + 7;
print((2**n-ans(n))%modulo);
最后,我的問題。 我需要此代碼才能用於n = 999999。
如何將最壞的情況降到最低? 任何指示或技巧都很好。
這是一個完整的解決方案,它基於以下觀察結果:三天或更長時間的有效解決方案必須從以下一項開始:
0
10
110
其中1表示當天使用了通行證,0表示未使用。
第一種形式存在有效(n-1)可能性,第二種形式存在有效(n-2)可能性,第三種形式存在有效(n-3)可能性。
然后重復是:
有效(n)=有效(n-1)+有效(n-2)+有效(n-3)
基本情況為有效(0)= 1,有效(1)= 2和有效(2)=4。必須注意,有效(0)為1,而不是零。 那是因為當n = 0時,只有一個解決方案,即空序列。 這不僅在數學上是正確的,而且還需要遞歸才能正確運行。
該代碼執行三件事以使其快速運行:
這是代碼:
cache = {}
modulus = 10**9 + 7
def valid(n):
if n in cache:
return cache[n]
if n == 0:
v = 1
elif n == 1:
v = 2
elif n == 2:
v = 4
else:
v = valid(n-1) + valid(n-2) + valid(n-3)
v %= modulus
cache[n] = v
return v
def main():
# Preload the cache
for n in range(1000000):
valid(n)
print(valid(999999))
main()
這是輸出:
746580045
它在我的系統上運行不到2秒。
更新:這是一個最小的迭代解決方案,其靈感來自MFisherKDX使用的方法。 種子值的構建方式消除了特殊包裝的需要(初始v2為有效(0)):
modulus = 10**9 + 7
def valid(n):
v0, v1, v2 = 0, 1, 1
for i in range(n):
v0, v1, v2 = v1, v2, (v0 + v1 + v2) % modulus
return v2
print(valid(999999))
此解決方案可能會盡快獲得。 使用中間結果后,它會丟棄中間結果,如果您只調用一次函數,這很好。
這是我的答案。 自下而上的解決方案。 與湯姆自上而下且同樣有效的答案進行比較。 在第j
天,它會跟蹤在第j
天使用通行證的可能性的數量以及在j
和j-1
兩者上使用通行證的可能性的數量。
def ans(n):
total = 1
tcd = 0 #total used at current day
tcpd = 0 #total used at current and previous day
m = 1000000007
for j in range(0, n):
next_tot = 2*total - tcpd
next_tcd = total - tcpd
next_tcpd = tcd - tcpd
total = next_tot % m
tcd = next_tcd % m
tcpd = next_tcpd % m
return total
print(ans(999999))
結果是746580045
,在我的系統上花費了400ms。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.