Python dynamic programming performance difference
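I'm studying dynamic programming by solving Leetcode problems, and I frequently get Time Limit Exceeded errors even though I'm caching my results. Can anyone explain why my version is so much slower than the official one?
There are obvious differences in the code: for example, I use a class method for the recursion while the official answer does not, and my recursive function returns a numeric value while the official one returns a boolean. None of these looks like a meaningful difference, yet the performance gap is dramatic.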
Here is my version. It takes 0.177669 seconds to run and gets a Time Limit Exceeded error.

    import datetime as dt
    from typing import List
    from functools import lru_cache

    class Solution:
        def canPartition(self, nums: List[int]) -> bool:
            self.nums = nums
            total = sum(self.nums)
            if total % 2 == 1:
                return False
            half_total = total // 2
            return self.traverse(half_total, 0) == 0

        @lru_cache(maxsize=None)
        def traverse(self, subset_sum, index):
            if subset_sum < 0:
                return float('inf')
            elif index == len(self.nums):
                return subset_sum
            else:
                include = self.traverse(subset_sum - self.nums[index], index + 1)
                exclude = self.traverse(subset_sum, index + 1)
                best = min(include, exclude)
                return best

    test_case = [20,68,68,11,48,18,50,5,3,51,52,11,13,11,38,100,30,87,1,56,85,63,14,96,7,17,54,11,32,61,94,13,85,10,78,57,69,92,66,28,70,20,3,29,10,73,89,86,28,48,69,54,87,11,91,32,59,4,88,20,81,100,29,75,79,82,6,74,66,30,9,6,83,54,54,53,80,94,64,77,22,7,22,26,12,31,23,26,65,65,35,36,34,1,12,44,22,73,59,99]
    solution = Solution()
    start = dt.datetime.now()
    print(solution.canPartition(test_case))
    end = dt.datetime.now()
    print((end-start).total_seconds())

And here is the official answer. It takes only 0.000165 seconds!

    import datetime as dt
    from typing import List, Tuple
    from functools import lru_cache

    class Solution:
        def canPartition(self, nums: List[int]) -> bool:
            @lru_cache(maxsize=None)
            def dfs(nums: Tuple[int], n: int, subset_sum: int) -> bool:
                # Base cases
                if subset_sum == 0:
                    return True
                if n == 0 or subset_sum < 0:
                    return False
                result = (dfs(nums, n - 1, subset_sum - nums[n - 1])
                          or dfs(nums, n - 1, subset_sum))
                return result

            # find sum of array elements
            total_sum = sum(nums)

            # if total_sum is odd, it cannot be partitioned into equal sum subsets
            if total_sum % 2 != 0:
                return False

            subset_sum = total_sum // 2
            n = len(nums)
            return dfs(tuple(nums), n - 1, subset_sum)

    test_case = [20,68,68,11,48,18,50,5,3,51,52,11,13,11,38,100,30,87,1,56,85,63,14,96,7,17,54,11,32,61,94,13,85,10,78,57,69,92,66,28,70,20,3,29,10,73,89,86,28,48,69,54,87,11,91,32,59,4,88,20,81,100,29,75,79,82,6,74,66,30,9,6,83,54,54,53,80,94,64,77,22,7,22,26,12,31,23,26,65,65,35,36,34,1,12,44,22,73,59,99]
    solution = Solution()
    start = dt.datetime.now()
    print(solution.canPartition(test_case))
    end = dt.datetime.now()
    print((end-start).total_seconds())

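If you want to understand performance, you need to profile your code. Profiling lets you see where your code spends its time.
CPython ships with a built-in profiling module called cProfile. You might also want to look at a third-party tool such as line_profiler.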
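As a rough illustration of that advice, here is a minimal sketch of profiling a recursive function with cProfile and pstats from the standard library. The names can_partition and profile_call are hypothetical helpers made up for this example; can_partition is only a function-based stand-in for the code in the question, not the asker's exact class.

```python
import cProfile
import io
import pstats
from functools import lru_cache


def can_partition(nums):
    """Hypothetical stand-in for the min-based code in the question."""
    total = sum(nums)
    if total % 2:
        return False

    @lru_cache(maxsize=None)
    def traverse(remaining, index):
        if remaining < 0:
            return float("inf")
        if index == len(nums):
            return remaining
        include = traverse(remaining - nums[index], index + 1)
        exclude = traverse(remaining, index + 1)
        return min(include, exclude)

    return traverse(total // 2, 0) == 0


def profile_call(func, *args):
    """Run func under cProfile and return (result, report text)."""
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args)
    buffer = io.StringIO()
    stats = pstats.Stats(profiler, stream=buffer)
    # Sort by cumulative time and keep only the top few entries.
    stats.sort_stats("cumulative").print_stats(5)
    return result, buffer.getvalue()


if __name__ == "__main__":
    result, report = profile_call(can_partition, [3, 1, 4, 2, 2])
    print(result)
    print(report)
```

The call counts in the report are usually the most telling column for a recursive solution like this one: a large number of calls to the recursive function points at the algorithm rather than at Python overhead.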
In the former version, all possible cases are searched. In the latter, the algorithm stops as soon as it finds a feasible solution.
In the first version:

    include = self.traverse(subset_sum - self.nums[index], index + 1)
    # Suppose include is zero: the answer has already been found,
    # but the algorithm still computes exclude, which is unnecessary.
    exclude = self.traverse(subset_sum, index + 1)

In the second version:

    result = (dfs(nums, n - 1, subset_sum - nums[n - 1])
              or dfs(nums, n - 1, subset_sum))
    # Because the logical operator short-circuits,
    # if the first branch has already found a solution,
    # the second branch is never executed.

Simply adding an if-check improves performance:

    include = self.traverse(subset_sum - self.nums[index], index + 1)
    # Check whether we are already done:
    if include == 0:
        return include
    exclude = self.traverse(subset_sum, index + 1)

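For context, here is how that check might look when placed into the whole class from the question. This is a sketch assembled from the question's code plus the if-check, not a tested Leetcode submission, and it should be used with a fresh Solution instance per input because the cache is keyed on self rather than on the contents of nums.

```python
from functools import lru_cache
from typing import List


class Solution:
    def canPartition(self, nums: List[int]) -> bool:
        self.nums = nums
        total = sum(self.nums)
        if total % 2 == 1:
            return False
        return self.traverse(total // 2, 0) == 0

    @lru_cache(maxsize=None)
    def traverse(self, subset_sum, index):
        if subset_sum < 0:
            return float('inf')
        if index == len(self.nums):
            return subset_sum
        include = self.traverse(subset_sum - self.nums[index], index + 1)
        # Early exit: 0 is the smallest value traverse can return,
        # so the exclude branch cannot improve on it.
        if include == 0:
            return include
        exclude = self.traverse(subset_sum, index + 1)
        return min(include, exclude)


if __name__ == "__main__":
    print(Solution().canPartition([1, 5, 11, 5]))
```

The early return is safe because a negative remainder is mapped to infinity, so no branch can ever return less than zero.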
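With min, you have to run both recursions down to the deepest level. By using booleans you can short-circuit this process, and the other program takes advantage of that shortcut.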
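One way to see the effect is to count how many distinct states each style evaluates. The sketch below is an assumption-laden illustration: count_calls_min and count_calls_bool are hypothetical names, both are simplified function-based variants of the two programs rather than the originals, and cache misses from lru_cache's cache_info() are used as a proxy for the amount of real work done.

```python
from functools import lru_cache


def count_calls_min(nums):
    """min-based recursion: both branches are always evaluated."""
    total = sum(nums)
    if total % 2:
        return False, 0

    @lru_cache(maxsize=None)
    def traverse(remaining, index):
        if remaining < 0:
            return float("inf")
        if index == len(nums):
            return remaining
        return min(traverse(remaining - nums[index], index + 1),
                   traverse(remaining, index + 1))

    answer = traverse(total // 2, 0) == 0
    # Each cache miss corresponds to one state that was actually computed.
    return answer, traverse.cache_info().misses


def count_calls_bool(nums):
    """Boolean recursion: `or` skips the second branch after a success."""
    total = sum(nums)
    if total % 2:
        return False, 0

    @lru_cache(maxsize=None)
    def dfs(remaining, n):
        if remaining == 0:
            return True
        if n == 0 or remaining < 0:
            return False
        return dfs(remaining - nums[n - 1], n - 1) or dfs(remaining, n - 1)

    answer = dfs(total // 2, len(nums))
    return answer, dfs.cache_info().misses


if __name__ == "__main__":
    sample = [1, 5, 11, 5]
    print("min :", count_calls_min(sample))
    print("bool:", count_calls_bool(sample))
```

On inputs that do have a valid partition, the boolean variant should report noticeably fewer computed states, since it stops exploring once one branch succeeds.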