Python dynamic programming performance difference
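I'm studying dynamic programming by solving LeetCode problems, and I often get Time Limit Exceeded errors even though I'm caching my results. Can anyone explain why my version is so much slower than the official one?
There are obvious differences in the code: for example, I use a class method for the recursion while the official answer does not, my recursive function returns a numeric value while the official one does not, and so on. None of these looks like a meaningful difference, yet the performance gap is huge.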
My version. It takes 0.177669 seconds to run and gets a Time Limit Exceeded error.
import datetime as dt
from typing import List
from functools import lru_cache
class Solution:
    def canPartition(self, nums: List[int]) -> bool:
        self.nums = nums
        total = sum(self.nums)
        if total % 2 == 1:
            return False
        half_total = total // 2
        return self.traverse(half_total, 0) == 0

    @lru_cache(maxsize=None)
    def traverse(self, subset_sum, index):
        if subset_sum < 0:
            return float('inf')
        elif index == len(self.nums):
            return subset_sum
        else:
            include = self.traverse(subset_sum - self.nums[index], index + 1)
            exclude = self.traverse(subset_sum, index + 1)
            best = min(include, exclude)
            return best

test_case = [20,68,68,11,48,18,50,5,3,51,52,11,13,11,38,100,30,87,1,56,85,63,14,96,7,17,54,11,32,61,94,13,85,10,78,57,69,92,66,28,70,20,3,29,10,73,89,86,28,48,69,54,87,11,91,32,59,4,88,20,81,100,29,75,79,82,6,74,66,30,9,6,83,54,54,53,80,94,64,77,22,7,22,26,12,31,23,26,65,65,35,36,34,1,12,44,22,73,59,99]
solution = Solution()
start = dt.datetime.now()
print(solution.canPartition(test_case))
end = dt.datetime.now()
print((end-start).total_seconds())
This is the official answer. It takes only 0.000165 seconds!
import datetime as dt
from typing import List, Tuple
from functools import lru_cache
class Solution:
    def canPartition(self, nums: List[int]) -> bool:
        @lru_cache(maxsize=None)
        def dfs(nums: Tuple[int], n: int, subset_sum: int) -> bool:
            # Base cases
            if subset_sum == 0:
                return True
            if n == 0 or subset_sum < 0:
                return False
            result = (dfs(nums, n - 1, subset_sum - nums[n - 1])
                      or dfs(nums, n - 1, subset_sum))
            return result

        # find sum of array elements
        total_sum = sum(nums)
        # if total_sum is odd, it cannot be partitioned into equal sum subsets
        if total_sum % 2 != 0:
            return False
        subset_sum = total_sum // 2
        n = len(nums)
        return dfs(tuple(nums), n - 1, subset_sum)

test_case = [20,68,68,11,48,18,50,5,3,51,52,11,13,11,38,100,30,87,1,56,85,63,14,96,7,17,54,11,32,61,94,13,85,10,78,57,69,92,66,28,70,20,3,29,10,73,89,86,28,48,69,54,87,11,91,32,59,4,88,20,81,100,29,75,79,82,6,74,66,30,9,6,83,54,54,53,80,94,64,77,22,7,22,26,12,31,23,26,65,65,35,36,34,1,12,44,22,73,59,99]
solution = Solution()
start = dt.datetime.now()
print(solution.canPartition(test_case))
end = dt.datetime.now()
print((end-start).total_seconds())
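If you want to understand performance, you need to profile your code. Profiling lets you see where your code spends its time.
CPython ships with a built-in profiling module called cProfile. You might also want to look at, for example, line_profiler.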
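As a rough illustration of the profiling advice, here is a minimal sketch that runs a min-based memoized recursion under cProfile and captures the report as text. The names `can_partition` and `profile_call` and the small sample input are made up for this example; they are not from the question or the official answer.

```python
import cProfile
import io
import pstats
from functools import lru_cache


def can_partition(nums):
    """Min-based memoized recursion, written as a plain function (hypothetical name)."""
    total = sum(nums)
    if total % 2:
        return False

    @lru_cache(maxsize=None)
    def traverse(subset_sum, index):
        if subset_sum < 0:
            return float("inf")
        if index == len(nums):
            return subset_sum
        include = traverse(subset_sum - nums[index], index + 1)
        exclude = traverse(subset_sum, index + 1)
        return min(include, exclude)

    return traverse(total // 2, 0) == 0


def profile_call(func, *args):
    """Run func under cProfile; return its result and a text report (hypothetical helper)."""
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args)
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    # Sort by cumulative time and show only the top five entries.
    stats.sort_stats("cumulative").print_stats(5)
    return result, stream.getvalue()


result, report = profile_call(can_partition, [1, 5, 11, 5])
print(result)
print(report)
```

The call counts in the report are usually the first thing to look at here: a recursion that visits far more states than expected shows up as a very large number of calls to the recursive function.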
In the former version, every possible case is searched. In the latter, the algorithm stops as soon as it finds a feasible solution.
In the first version:
include = self.traverse(subset_sum - self.nums[index], index + 1)
# Suppose {include} is zero: the answer has already been found,
# but the algorithm still computes {exclude}, which is not necessary.
exclude = self.traverse(subset_sum, index + 1)
In the second version:
result = (dfs(nums, n - 1, subset_sum - nums[n - 1])
          or dfs(nums, n - 1, subset_sum))
# Because of the short-circuit behavior of the logical operator,
# if the first branch has already found a solution,
# the second branch will not be executed.
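The evaluation-order difference can be seen in isolation with a small sketch. The helper `make_branch` and the log lists are hypothetical names used only for this demonstration; the helper records which branches actually get evaluated.

```python
def make_branch(log):
    """Return a helper that records its name each time it is evaluated."""
    def branch(name, value):
        log.append(name)
        return value
    return branch


# Boolean style: `or` stops as soon as the left operand is truthy.
or_log = []
branch = make_branch(or_log)
found = branch("include", True) or branch("exclude", True)

# min style: both arguments are evaluated before min() is called.
min_log = []
branch = make_branch(min_log)
best = min(branch("include", 0), branch("exclude", 0))

print(or_log)   # ['include']
print(min_log)  # ['include', 'exclude']
```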
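Simply adding an if-check improves performance:
include = self.traverse(subset_sum - self.nums[index], index + 1)
# Check whether we are already done:
if include == 0:
    return include
exclude = self.traverse(subset_sum, index + 1)
With min, you have to run both recursions down to the deepest level. By using booleans you can short-circuit this process, and the other program takes advantage of that shortcut.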
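To get a feel for how much work the early exit saves, one option is to count how many distinct states the memoized recursion actually computes, with and without the check. This is a sketch under assumptions: `count_states`, its `early_exit` flag, and the small sample input are invented for illustration, and the recursion is rewritten as a nested function so that each run gets a fresh cache.

```python
from functools import lru_cache


def count_states(nums, early_exit):
    """Run the min-based recursion and return (answer, states computed)."""
    nums = tuple(nums)
    total = sum(nums)
    if total % 2 == 1:
        return False, 0

    @lru_cache(maxsize=None)
    def traverse(subset_sum, index):
        if subset_sum < 0:
            return float("inf")
        if index == len(nums):
            return subset_sum
        include = traverse(subset_sum - nums[index], index + 1)
        # Optional shortcut: a remainder of 0 cannot be improved on.
        if early_exit and include == 0:
            return include
        exclude = traverse(subset_sum, index + 1)
        return min(include, exclude)

    answer = traverse(total // 2, 0) == 0
    # Every cache miss corresponds to one state that had to be computed.
    return answer, traverse.cache_info().misses


sample = [1, 5, 11, 5]
full_answer, full_states = count_states(sample, early_exit=False)
fast_answer, fast_states = count_states(sample, early_exit=True)
print(full_answer, full_states)
print(fast_answer, fast_states)
```

Both runs return the same answer. The early-exit run only ever visits states the exhaustive run would also visit, so its count cannot be higher, and it is typically much lower when a valid partition exists.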