Python dynamic programming performance difference
I'm studying dynamic programming by doing LeetCode problems, and I frequently hit time limit exceeded errors even though I'm caching my results. Can anyone explain why my version is so much slower than the official version for this problem?
There are obvious differences in the code, e.g. I use a method on the class for the recursion while the official answer uses a nested function, and my recursive function returns numeric values while the official one returns booleans. None of these seem like meaningful differences, yet the performance difference is dramatic.
My version. This takes 0.177669 seconds to run and receives a time limit exceeded error.
import datetime as dt
from typing import List
from functools import lru_cache


class Solution:
    def canPartition(self, nums: List[int]) -> bool:
        self.nums = nums
        total = sum(self.nums)
        if total % 2 == 1:
            return False
        half_total = total // 2
        return self.traverse(half_total, 0) == 0

    @lru_cache(maxsize=None)
    def traverse(self, subset_sum, index):
        if subset_sum < 0:
            return float('inf')
        elif index == len(self.nums):
            return subset_sum
        else:
            include = self.traverse(subset_sum - self.nums[index], index + 1)
            exclude = self.traverse(subset_sum, index + 1)
            best = min(include, exclude)
            return best

test_case = [20,68,68,11,48,18,50,5,3,51,52,11,13,11,38,100,30,87,1,56,85,63,14,96,7,17,54,11,32,61,94,13,85,10,78,57,69,92,66,28,70,20,3,29,10,73,89,86,28,48,69,54,87,11,91,32,59,4,88,20,81,100,29,75,79,82,6,74,66,30,9,6,83,54,54,53,80,94,64,77,22,7,22,26,12,31,23,26,65,65,35,36,34,1,12,44,22,73,59,99]
solution = Solution()
start = dt.datetime.now()
print(solution.canPartition(test_case))
end = dt.datetime.now()
print((end-start).total_seconds())
This is the official answer. It takes only 0.000165 seconds!
import datetime as dt
from typing import List, Tuple
from functools import lru_cache


class Solution:
    def canPartition(self, nums: List[int]) -> bool:
        @lru_cache(maxsize=None)
        def dfs(nums: Tuple[int], n: int, subset_sum: int) -> bool:
            # Base cases
            if subset_sum == 0:
                return True
            if n == 0 or subset_sum < 0:
                return False
            result = (dfs(nums, n - 1, subset_sum - nums[n - 1])
                      or dfs(nums, n - 1, subset_sum))
            return result

        # find sum of array elements
        total_sum = sum(nums)
        # if total_sum is odd, it cannot be partitioned into equal sum subsets
        if total_sum % 2 != 0:
            return False
        subset_sum = total_sum // 2
        n = len(nums)
        return dfs(tuple(nums), n - 1, subset_sum)

test_case = [20,68,68,11,48,18,50,5,3,51,52,11,13,11,38,100,30,87,1,56,85,63,14,96,7,17,54,11,32,61,94,13,85,10,78,57,69,92,66,28,70,20,3,29,10,73,89,86,28,48,69,54,87,11,91,32,59,4,88,20,81,100,29,75,79,82,6,74,66,30,9,6,83,54,54,53,80,94,64,77,22,7,22,26,12,31,23,26,65,65,35,36,34,1,12,44,22,73,59,99]
solution = Solution()
start = dt.datetime.now()
print(solution.canPartition(test_case))
end = dt.datetime.now()
print((end-start).total_seconds())
If you want to know about performance, you need to profile your code. Profiling lets you see where your code spends its time.
CPython comes with a built-in profiling module called cProfile, but you might also want to look at e.g. line_profiler.
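As a rough illustration of the profiling suggestion, here is a minimal sketch using cProfile and pstats from the standard library. The names `can_partition` and `profile_report` are hypothetical, and `can_partition` is only a condensed stand-in for the question's min-based solution, not the original class.

```python
import cProfile
import io
import pstats
from functools import lru_cache


def can_partition(nums):
    """Hypothetical stand-in for the question's min-based solution."""
    total = sum(nums)
    if total % 2 == 1:
        return False

    @lru_cache(maxsize=None)
    def traverse(subset_sum, index):
        if subset_sum < 0:
            return float("inf")
        if index == len(nums):
            return subset_sum
        include = traverse(subset_sum - nums[index], index + 1)
        exclude = traverse(subset_sum, index + 1)
        return min(include, exclude)

    return traverse(total // 2, 0) == 0


def profile_report(func, *args, top=5):
    """Run func under cProfile and return (result, text report)."""
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args)
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    # Sort by cumulative time and keep only the `top` most expensive entries.
    stats.sort_stats("cumulative").print_stats(top)
    return result, stream.getvalue()


result, report = profile_report(can_partition, [1, 5, 11, 5])
print(result)
print(report)
```

The call counts in the report are usually the first thing to check here: a recursive function that is called far more often than expected points at the algorithm rather than at per-call overhead.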
In the first version, all possible cases are searched, while in the second, the algorithm stops as soon as a feasible solution has been found.
In the first version:
include = self.traverse(subset_sum - self.nums[index], index + 1)
# Suppose include is zero: the answer is already known,
# but the algorithm still computes exclude, which is unnecessary.
exclude = self.traverse(subset_sum, index + 1)
In the second version:
result = (dfs(nums, n - 1, subset_sum - nums[n - 1])
          or dfs(nums, n - 1, subset_sum))
# Because of the short-circuit behavior of the logical or operator,
# if the first branch has already found a solution,
# the second branch is not executed.
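The difference between the two snippets can be checked in isolation. The following is a small sketch, with a hypothetical helper `branch` that records each time it is evaluated: `or` skips its second operand once the first is truthy, whereas the arguments to `min` are all evaluated before `min` runs.

```python
calls = []


def branch(name, value):
    """Hypothetical helper: record that this branch was evaluated."""
    calls.append(name)
    return value


# Boolean style: the second operand is skipped when the first is truthy.
calls.clear()
result_or = branch("include", True) or branch("exclude", True)
or_calls = list(calls)

# Numeric style: both arguments are evaluated before min() is called.
calls.clear()
result_min = min(branch("include", 0), branch("exclude", 0))
min_calls = list(calls)

print(or_calls)   # ['include']
print(min_calls)  # ['include', 'exclude']
```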
Just adding an if-check will improve the performance:
include = self.traverse(subset_sum - self.nums[index], index + 1)
# Check whether we are already done:
if include == 0:
return include
exclude = self.traverse(subset_sum, index + 1)
For the same reason as above: you use the comparison function min on two recursively obtained results, so both recursions have to run to their deepest levels. By using booleans you can short-circuit this process, and the other program takes that shortcut.