Haskell Performance Optimization
I am writing code to find the nth Ramanujan-Hardy number. A Ramanujan-Hardy number is defined as
n = a^3 + b^3 = c^3 + d^3
meaning that n can be expressed as a sum of two cubes in two ways.
I wrote the following code in Haskell:
-- my own implementation for cube root. Expected time complexity is O(n^(1/3))
cube_root n = chelper 1 n
  where
    chelper i n = if i*i*i > n then (i-1) else chelper (i+1) n
-- It checks if the given number can be expressed as a^3 + b^3 = c^3 + d^3 (is Ramanujan-Hardy number?)
is_ram n = length [a | a <- [1..crn], b <- [(a+1)..crn], c <- [(a+1)..crn], d <- [(c+1)..crn], a*a*a + b*b*b == n && c*c*c + d*d*d == n] /= 0
  where
    crn = cube_root n
-- It finds the nth Ramanujan number by iterating from 1 until the nth number is found.
-- In the recursion, if x is a Ramanujan number, decrement n; either way, increment x.
-- When n reaches 0, the preceding number (x-1) was the desired Ramanujan number.
ram n = give_ram 1 n
  where
    give_ram x 0 = (x-1)
    give_ram x n = if is_ram x then give_ram (x+1) (n-1) else give_ram (x+1) n
In my opinion, the time complexity of checking whether a number is a Ramanujan number is O(n^(4/3)).
On running this code in ghci, it takes a long time even to find the 2nd Ramanujan number.
What are possible ways to optimize this code?
First, a small clarification of what we're looking for. A Ramanujan-Hardy number is one which may be written in two different ways as a sum of two cubes, i.e. a^3 + b^3 = c^3 + d^3 where a < b and a < c < d.
An obvious idea is to generate all of the cube sums in sorted order and then look for adjacent sums which are the same.
Here's a start - a function which generates all of the cube sums with a given first cube:
cubes a = [ (a^3+b^3, a, b) | b <- [a+1..] ]
All of the possible cube sums, in order, are then just:
allcubes = sort $ concat [ cubes 1, cubes 2, cubes 3, ... ]
but of course this won't work, since concat and sort don't work on infinite lists. However, since cubes a is an increasing sequence, we can sort all of the sequences together by merging them:
allcubes = cubes 1 `merge` cubes 2 `merge` cubes 3 `merge` ...
Here we are taking advantage of Haskell's lazy evaluation. The definition of merge is just:
merge [] bs = bs
merge as [] = as
merge as@(a:at) bs@(b:bt)
  = case compare a b of
      LT -> a : merge at bs
      EQ -> a : b : merge at bt
      GT -> b : merge as bt
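As a small illustration of why laziness matters here, the sketch below applies the same merge to two infinite increasing lists and takes only a finite prefix. The type signature and the names multiplesOf3, multiplesOf5 and firstTen are assumptions added for the example; they are not part of the original answer.

```haskell
-- Same merge as above, with a type signature added for clarity.
merge :: Ord a => [a] -> [a] -> [a]
merge [] bs = bs
merge as [] = as
merge as@(a:at) bs@(b:bt) =
  case compare a b of
    LT -> a : merge at bs
    EQ -> a : b : merge at bt
    GT -> b : merge as bt

-- Two infinite, strictly increasing lists (illustrative data only).
multiplesOf3, multiplesOf5 :: [Int]
multiplesOf3 = [3, 6 ..]
multiplesOf5 = [5, 10 ..]

-- Only the demanded prefix of the merged list is ever computed.
firstTen :: [Int]
firstTen = take 10 (multiplesOf3 `merge` multiplesOf5)
-- firstTen == [3,5,6,9,10,12,15,15,18,20]
```

Note that equal elements are both kept (15 appears twice), which is exactly what the later grouping step relies on.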
We still have a problem, since we don't know where to stop. We can solve that by having cubes a initiate cubes (a+1) at the appropriate time, i.e.
cubes a = ...an initial part... ++ (...the rest... `merge` cubes (a+1) )
The definition is accomplished using span:
cubes a = first ++ (rest `merge` cubes (a+1))
  where
    s = (a+1)^3 + (a+2)^3
    (first, rest) = span (\(x,_,_) -> x < s) [ (a^3+b^3,a,b) | b <- [a+1..]]
So now cubes 1 is the infinite series of all the possible sums a^3 + b^3 where a < b, in sorted order.
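A quick way to convince yourself that the stream really is sorted is to look at a finite prefix. The following is a minimal self-contained sketch; the type signatures, the Integer element type, and the helper names firstFive and nonDecreasing are assumptions added for illustration.

```haskell
merge :: Ord a => [a] -> [a] -> [a]
merge [] bs = bs
merge as [] = as
merge as@(a:at) bs@(b:bt) =
  case compare a b of
    LT -> a : merge at bs
    EQ -> a : b : merge at bt
    GT -> b : merge as bt

-- All triples (a^3 + b^3, a, b) with first cube >= a and a < b, sorted by sum.
cubes :: Integer -> [(Integer, Integer, Integer)]
cubes a = first ++ (rest `merge` cubes (a + 1))
  where
    -- Smallest sum that cubes (a+1) can produce; below it only cubes a contributes.
    s = (a + 1) ^ 3 + (a + 2) ^ 3
    (first, rest) = span (\(x, _, _) -> x < s) [ (a ^ 3 + b ^ 3, a, b) | b <- [a + 1 ..] ]

-- First few entries of the merged stream.
firstFive :: [(Integer, Integer, Integer)]
firstFive = take 5 (cubes 1)
-- firstFive == [(9,1,2),(28,1,3),(35,2,3),(65,1,4),(72,2,4)]

-- True if the sums in a finite list never decrease.
nonDecreasing :: [(Integer, Integer, Integer)] -> Bool
nonDecreasing xs = and (zipWith (<=) sums (drop 1 sums))
  where
    sums = [ x | (x, _, _) <- xs ]
```

For example, nonDecreasing (take 1000 (cubes 1)) should evaluate to True.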
To find the Ramanujan-Hardy numbers, we just group together adjacent elements of the list that have the same first component:
-- groupBy comes from Data.List: import Data.List (groupBy)
sameSum (x,a,b) (y,c,d) = x == y
rjgroups = groupBy sameSum $ cubes 1
The groups we are interested in are those whose length is > 1:
rjnumbers = filter (\g -> length g > 1) rjgroups
The first 10 solutions are:
ghci> take 10 rjnumbers
[(1729,1,12),(1729,9,10)]
[(4104,2,16),(4104,9,15)]
[(13832,2,24),(13832,18,20)]
[(20683,10,27),(20683,19,24)]
[(32832,4,32),(32832,18,30)]
[(39312,2,34),(39312,15,33)]
[(40033,9,34),(40033,16,33)]
[(46683,3,36),(46683,27,30)]
[(64232,17,39),(64232,26,36)]
[(65728,12,40),(65728,31,33)]
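Putting the pieces together, the original question (the nth Ramanujan-Hardy number) can be answered by indexing into this list. The following is a self-contained sketch under a few assumptions: nthRam is a hypothetical helper name that does not appear in the answer, n counts from 1, and Integer is used so that large sums cannot overflow.

```haskell
import Data.List (groupBy)

merge :: Ord a => [a] -> [a] -> [a]
merge [] bs = bs
merge as [] = as
merge as@(a:at) bs@(b:bt) =
  case compare a b of
    LT -> a : merge at bs
    EQ -> a : b : merge at bt
    GT -> b : merge as bt

-- Sorted stream of (a^3 + b^3, a, b) with a < b, as defined above.
cubes :: Integer -> [(Integer, Integer, Integer)]
cubes a = first ++ (rest `merge` cubes (a + 1))
  where
    s = (a + 1) ^ 3 + (a + 2) ^ 3
    (first, rest) = span (\(x, _, _) -> x < s) [ (a ^ 3 + b ^ 3, a, b) | b <- [a + 1 ..] ]

-- Groups of representations that share a sum, keeping groups of size >= 2.
rjnumbers :: [[(Integer, Integer, Integer)]]
rjnumbers = filter (\g -> length g > 1) (groupBy sameSum (cubes 1))
  where
    sameSum (x, _, _) (y, _, _) = x == y

-- Hypothetical helper: the n-th Ramanujan-Hardy number, counting from 1.
nthRam :: Int -> Integer
nthRam n =
  case rjnumbers !! (n - 1) of
    ((s, _, _) : _) -> s
    []              -> error "unreachable: groupBy never yields empty groups"
```

With these definitions nthRam 1 is 1729 and nthRam 10 is 65728, matching the listing above. To print one group per line in ghci, something like mapM_ print (take 10 rjnumbers) can be used.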
Your is_ram function checks for a Ramanujan number by trying all values for a, b, c, d up to the cube root, and you then loop over all n.
An alternative approach would be to simply loop over values for a and b up to some limit and increment an array at index a^3+b^3 by 1 for each choice.
The Ramanujan numbers can then be found by iterating over the non-zero values in this array and returning the places where the array content is >= 2 (meaning that at least 2 ways have been found of computing that result).
I believe this would be O(n^(2/3)), compared to your method, which is O(n * n^(4/3)), i.e. O(n^(7/3)).
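A minimal sketch of this counting idea in Haskell, assuming an immutable array built with accumArray from Data.Array (from the array package that ships with GHC). The names icbrt and ramanujanUpTo, and the choice to bound the search by a maximum value maxN >= 0, are assumptions made for the example rather than part of the answer.

```haskell
import Data.Array (accumArray, assocs)

-- Largest r with r^3 <= n, for n >= 0 (linear search, fine for the small r used here).
icbrt :: Int -> Int
icbrt n = last (takeWhile (\r -> r * r * r <= n) [0 ..])

-- All n <= maxN that are a sum of two distinct positive cubes in at least two ways.
ramanujanUpTo :: Int -> [Int]
ramanujanUpTo maxN = [ n | (n, ways) <- assocs counts, ways >= 2 ]
  where
    -- Any representation of a number <= maxN uses cubes of values <= r.
    r = icbrt maxN
    -- counts ! n is the number of pairs a < b with a^3 + b^3 == n.
    counts = accumArray (+) (0 :: Int) (0, maxN)
               [ (s, 1)
               | a <- [1 .. r]
               , b <- [a + 1 .. r]
               , let s = a * a * a + b * b * b
               , s <= maxN
               ]
```

For example, ramanujanUpTo 50000 gives [1729,4104,13832,20683,32832,39312,40033,46683]. This finds every such number up to a bound rather than the nth one directly; to get the nth, one option is to increase the bound until the list has at least n elements.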