[英]How do I optimize this algorithm so it doesn't exceed the given time limit?
首先,我正在尝试使用 Python 解决以下问题:
等差数列是 a, a+b, a+2b, ..., a+nb 形式的序列,其中 n=0, 1, 2, 3, ... 。 对于这个问题,a 是非负 integer,b 是正 integer。
编写一个程序,在双平方集合 S 中找到所有长度为 n 的算术级数。 双平方集合定义为所有形式为 p 2 + q 2的整数的集合(其中 p 和 q 是非负整数)。
时间限制:5秒
输入格式:
第 1 行:N (3 <= N <= 25),要搜索的级数长度 第 2 行:M (1 <= M <= 250),将搜索限制在 0 <= 的二平方的上限p,q <= M。样本输入:
5 7
Output 格式:
如果没有找到序列,则单行读取NONE
。 否则,output 一行或多行,每行有两个整数:找到的序列中的第一个元素和同一序列中连续元素之间的差异。 这些行应该首先以最小差异序列和这些序列中最小的起始编号进行排序。将不会超过 10,000 个序列。
样品 output:
37 4 2 8 29 8 1 12 5 12 13 12 17 12 5 20 2 24
我编写的代码确实有效,但它大大超出了给定的时间限制。 我不知道这是算法本身引起的问题,还是只是 Python 引起的问题。 有人可以建议一种在 5 秒内运行的方法吗? 这是代码:
fin = open ('ariprog.in', 'r')
fout = open ('ariprog.out', 'w')
N:int=int(fin.readline().strip())#take N
M:int=int(fin.readline().strip())#take M
s:set=set()#set s
ans:list=[]#the list that will contain the pairs
mx:int=M**2 * 2#absolute max value in the set s
for j in range(M+1):#produce the set s
for i in range(j,M+1):
h:int=(i**2)+(j**2)
s.add(h)
for stepVal in range(1,(mx//(N-1))+1):#iterates over the possible step values
for initial in s:#iterates over the possible starting points in the set s
count:int=1
k:int=initial
while count<N:
if k+stepVal not in s:break #if the loop breaks,
k+=stepVal #we don't add the pair to the answer list
count+=1
else:ans.append([initial,stepVal])
ans.sort(key=lambda x:x[1]) #sort the answer list
if not ans:fout.write("NONE" + "\n")
for i,e in ans:
pr=f'{str(i)} {str(e)}\n'
fout.write(pr)
fout.close()
当出现测试用例 7 时,我收到以下消息:
> Run 7: Execution error: Your program (`ariprog') used more than
the allotted runtime of 5 seconds (it ended or was stopped at
5.242 seconds) when presented with test case 7. It used 10368 KB
of memory.
------ Data for Run 7 [length=7 bytes] ------
21
200
----------------------------
内部循环可以优化 - 但它需要更多的逻辑。
这个想法是,在检查initial
之后,您将根据结果知道initial+stepVal
也没有序列,或者您不必检查前N-1
数字(因为它们已经被检查过)。
而且您似乎按顺序生成结果,因此无需对它们进行排序。
我删除了变量定义以保存 memory
我还尽可能使用列表推导。
fin = open ('ariprog.in', 'r')
fout = open ('ariprog.out', 'w')
ans:list=[]#the list that will contain the pairs
s = set([(i**2)+(j**2) for j in range(int(fin.readline().strip())+1) for i in range(j,int(fin.readline().strip())+1)])
for stepVal in range(1,(int(fin.readline().strip())**2 * 2//(int(fin.readline().strip())-1))+1):#iterates over the possible step values
for initial in s:#iterates over the possible starting points in the set s
count:int=1
k:int=initial
while count<int(fin.readline().strip()):
if k+stepVal not in s:break #if the loop breaks,
k+=stepVal #we don't add the pair to the answer list
count+=1
else:ans.append([initial,stepVal])
ans.sort(key=lambda x:x[1]) #sort the answer list
if not ans:fout.write("NONE" + "\n")
[fout.write (f'{str(i)} {str(e)}\n') for i,e in ans]
fout.close()
可以稍微加快速度的方法是将解决方案作为每个步长的整个双平方集的渐进细化/过滤器来处理。 这可以将执行时间减少大约 20%
然而,我发现真正的速度提升可能来自一个猜想(我还无法证明),其中只有 4 的倍数的跨步可以存在于 4 个或更多元素的序列中。
def findSequences(N,M,step=4):
S = { p*p+q*q for p in range(M+1) for q in range(p,M+1) }
if N<4 : step = 1
for b in range(step,2*M*M//(N-1)+1,step):
eligible = S
for n in range(1,N):
eligible = eligible.intersection(e-b for e in eligible)
if not eligible: break
for a in sorted(eligible):
yield a,b
print(*findSequences(5,7) )
# (1, 4) (37, 4) (2, 8) (29, 8) (1, 12) (5, 12) (13, 12) (17, 12) (5, 20) (2, 24)
print(*findSequences(21,200) )
# (1217, 84) (2434, 168) (4868, 336) (6085, 420) (9736, 672) (10953, 756) (12170, 840) (12953, 924) (15821, 1092)
在我的笔记本电脑上,这个猜想将执行时间缩短了 75%
也许对数论有更深入了解的人可以阐明我关于 4 的倍数的假设。
尽管我无法解释为什么会这样,但该猜想适用于 M <= 250,如下所示:
all(b%4==0 for a,b in findSequences(4,250,1)) # returns True
并且由于较长的序列是较短序列的扩展,因此它适用于 N = 5,6,7, ...
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.