
Python: efficient way to find palindrome numbers that are sums of consecutive squares

This is a problem from Codewars.

Given an integer N, write a function values that finds how many numbers in the interval (1,...,N) are palindromes and can be expressed as a sum of consecutive squares.

For example, values(100) will return 3. In fact, the numbers smaller than 100 that have the above property are:

  • 5 = 1+4
  • 55 = 1+4+9+16+25
  • 77 = 16+25+36

Palindromes that are perfect squares, such as 4, 9, 121, etc., do not count. Tests are up to N=10^7.

I can solve the problem and pass all the test cases, but I cannot meet the efficiency requirements (the process is killed after 12s).

Am I simply missing some key observation, or is there something wrong with the efficiency of my code below? I broke it down a bit for better readability.

from numpy import cumsum

def is_palindrome(n):
    return str(n) == str(n)[::-1]

def values(n):

    limit = int(n**0.5) # could be improved
    pals = []

    for i in range(1,limit):
        # take numbers in the cumsum that are palindromes < n
        sums = [p for p in cumsum([j**2 for j in range(i,limit)]) if p<n and is_palindrome(p)]
        # remove perfect-squares and those already counted
        pals.extend(k for k in sums if not (k in pals) and not (k**0.5).is_integer())

    return len(pals)

Note: I'm aware that checking that a number is a perfect square using sqrt(n).is_integer() might not be perfect but should be enough for this case.
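If you want to remove that doubt entirely, an exact integer test is straightforward. A minimal sketch, assuming Python 3.8+ for math.isqrt; is_perfect_square is a hypothetical helper name, not something the kata provides:

```python
from math import isqrt  # available since Python 3.8


def is_perfect_square(k: int) -> bool:
    """Exact integer test, with no floating-point rounding involved."""
    if k < 0:
        return False
    r = isqrt(k)  # floor of the true square root
    return r * r == k
```

Because it stays in integer arithmetic, it gives the same answer for any size of k, which k ** 0.5 cannot guarantee once k gets large.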

Apart from working on computational efficiency, you could improve your algorithm strategy. There is a closed formula for the sum of the first n squares: s(n) = 1²+2²+...+n² = n(n+1)(2n+1)/6. So instead of adding m²+(m+1)²+...+n², you could calculate s(n)-s(m-1). Precomputing a list of all s(n), then using itertools to form all pairs and subtract them, should speed up your program.

More information about pyramidal numbers
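A minimal sketch of the pyramidal-number idea. It reads "perfect squares do not count" as "at least two consecutive squares are required", and pyramidal and the other names are illustrative. itertools.combinations enumerates every pair (a, b), so this demonstrates the formula rather than the fastest possible loop:

```python
from itertools import combinations


def pyramidal(k: int) -> int:
    """Square pyramidal number s(k) = 1² + 2² + ... + k²."""
    return k * (k + 1) * (2 * k + 1) // 6


def is_palindrome(k: int) -> bool:
    s = str(k)
    return s == s[::-1]


def values(n: int) -> int:
    # m is the largest k with k*k <= n; no square above that can appear in a sum <= n
    m = 0
    while (m + 1) * (m + 1) <= n:
        m += 1
    s = [pyramidal(k) for k in range(m + 1)]  # s(0), s(1), ..., s(m)
    found = set()  # a set, because one number may arise from several pairs
    for a, b in combinations(range(m + 1), 2):
        # s(b) - s(a) = (a+1)² + ... + b², which has b - a terms
        if b - a < 2:
            continue  # a single square does not count
        total = s[b] - s[a]
        if total <= n and is_palindrome(total):
            found.add(total)
    return len(found)
```

For example, s(6) - s(3) = 91 - 14 = 77 = 4² + 5² + 6², and values(100) returns 3.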

Following @nm's suggestion, you could precompute all the values for n <= 10e7 and return the number of matches:

import bisect

def values(n):
    # every qualifying number up to 10e7, sorted, so bisect_right(vals, n) counts those <= n
    vals = [5, 55, 77, 181, 313, 434, 505, 545, 595, 636, 818, 1001, 1111, 1441, 1771, 4334, 6446, 17371, 17871, 19691, 21712, 41214, 42924, 44444, 46564, 51015, 65756, 81818, 97679, 99199, 108801, 127721, 137731, 138831, 139931, 148841, 161161, 166661, 171171, 188881, 191191, 363363, 435534, 444444, 485584, 494494, 525525, 554455, 629926, 635536, 646646, 656656, 904409, 923329, 944449, 964469, 972279, 981189, 982289, 1077701, 1224221, 1365631, 1681861, 1690961, 1949491, 1972791, 1992991, 2176712, 2904092, 3015103, 3162613, 3187813, 3242423, 3628263, 4211124, 4338334, 4424244, 4776774, 5090905, 5258525, 5276725, 5367635, 5479745, 5536355, 5588855, 5603065, 5718175, 5824285, 6106016, 6277726, 6523256, 6546456, 6780876, 6831386, 6843486, 6844486, 7355537, 8424248, 9051509, 9072709, 9105019, 9313139, 9334339, 9343439, 9435349, 9563659, 9793979, 9814189, 9838389, 9940499, 10711701, 11122111, 11600611, 11922911, 12888821, 13922931, 15822851, 16399361, 16755761, 16955961, 17488471, 18244281, 18422481, 18699681, 26744762, 32344323, 32611623, 34277243, 37533573, 40211204, 41577514, 43699634, 44366344, 45555554, 45755754, 46433464, 47622674, 49066094, 50244205, 51488415, 52155125, 52344325, 52722725, 53166135, 53211235, 53933935, 55344355, 56722765, 56800865, 57488475, 58366385, 62988926, 63844836, 63866836, 64633646, 66999966, 67233276, 68688686, 69388396, 69722796, 69933996, 72299227, 92800829, 95177159, 95544559, 97299279]
    return bisect.bisect_right(vals, n)

You get into the realm of a few hundred nanoseconds...
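Instead of pasting the list as a literal, the same kind of table could be generated once when the module loads and then queried with bisect. This is a sketch under two assumptions: 10**7 is the largest input (the limit stated in the question), and "perfect squares do not count" is read as "at least two consecutive squares". _build_table and TABLE are hypothetical names:

```python
import bisect


def _is_palindrome(k: int) -> bool:
    s = str(k)
    return s == s[::-1]


def _build_table(limit: int) -> list:
    """Sorted palindromes <= limit that are sums of two or more consecutive squares."""
    found = set()
    start = 1
    # Stop once even the shortest run (two terms) starting here exceeds the limit
    while start * start + (start + 1) * (start + 1) <= limit:
        total = start * start
        k = start + 1
        while True:
            total += k * k  # total = start² + ... + k²
            if total > limit:
                break
            if _is_palindrome(total):
                found.add(total)
            k += 1
        start += 1
    return sorted(found)


# Assumption: no test asks for more than 10**7
TABLE = _build_table(10 ** 7)


def values(n: int) -> int:
    # Number of table entries that are <= n
    return bisect.bisect_right(TABLE, n)
```

Building the table costs some time up front, but only once. Every call after that is a single binary search.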

Some mistakes in your code:
1 - Use p <= n instead of p < n (try n = 77, which should count 77 itself).
2 - Use a set instead of a list to store your answers; that will speed up your solution.
3 - You don't need to call cumsum over the whole range [i, limit); you can stop accumulating as soon as the sum exceeds n. This will speed up your solution too.
4 - Try to be less "pythonic". Do not fill your code with long and ugly list comprehensions. (This is not exactly a mistake.)

This is your code after some changes:

def is_palindrome(s):
    return s == s[::-1]

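def values(n):
    # squares[k] holds the prefix sum 1² + 2² + ... + k² (squares[0] = 0)
    squares = [0]
    i = 1
    while i * i <= n:
        squares.append(squares[-1] + i * i)
        i += 1
    pals = set()  # constant-time membership, and duplicates are ignored
    for i in range(1, len(squares)):
        # squares[j] - squares[i - 1] = i² + ... + j²; starting at j = i + 1 means at least two terms
        j = i + 1
        while j < len(squares) and (squares[j] - squares[i - 1]) <= n:
            s = squares[j] - squares[i - 1]
            if is_palindrome(str(s)):
                pals.add(s)
            j += 1
    return len(pals)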
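Off-by-one errors like p < n only show up at specific inputs such as n = 77, so one way to gain confidence is to compare against a deliberately naive reference on many small n. A sketch: values_naive is a hypothetical helper, and values is repeated from the rewrite so the block runs on its own:

```python
def is_palindrome(s):
    return s == s[::-1]


def values(n):
    # Same prefix-sum logic as the rewritten function
    squares = [0]
    i = 1
    while i * i <= n:
        squares.append(squares[-1] + i * i)
        i += 1
    pals = set()
    for i in range(1, len(squares)):
        j = i + 1
        while j < len(squares) and squares[j] - squares[i - 1] <= n:
            s = squares[j] - squares[i - 1]
            if is_palindrome(str(s)):
                pals.add(s)
            j += 1
    return len(pals)


def values_naive(n):
    """Hypothetical slow reference: sum every run of two or more squares directly."""
    found = set()
    start = 1
    while start * start <= n:
        length = 2
        while True:
            total = sum(k * k for k in range(start, start + length))
            if total > n:
                break
            if is_palindrome(str(total)):
                found.add(total)
            length += 1
        start += 1
    return len(found)


# Compare on every n in a small range (including n = 77)
mismatches = [n for n in range(1, 400) if values(n) != values_naive(n)]
print(mismatches)  # an empty list means the two versions agree
```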

You can use some itertools tricks, which will likely speed this up a bit.

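import itertools

n = 100  # example input, matching the illustration below
limit = int(n**0.5)  # as before, this can be improved but non-obviously.
squares = [i**2 for i in range(1, limit+1)]
num_squares = len(squares)  # inline this to save on lookups

def runs_starting_at(i):
    # Calling a function binds i now; a generator expression nested directly
    # in the list comprehension would look i up lazily and only see its final value.
    return (squares[i:j] for j in range(i+2, num_squares))

seqs = [runs_starting_at(i) for i in range(num_squares-2)]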

seqs is now a list of generators that build your square sequences. For n=100, materialized as lists, they produce:

[ [[1, 4], [1, 4, 9], [1, 4, 9, 16], [1, 4, 9, 16, 25], [1, 4, 9, 16, 25, 36], [1, 4, 9, 16, 25, 36, 49], [1, 4, 9, 16, 25, 36, 49, 64], [1, 4, 9, 16, 25, 36, 49, 64, 81]],
  [[4, 9], [4, 9, 16], [4, 9, 16, 25], [4, 9, 16, 25, 36], [4, 9, 16, 25, 36, 49], [4, 9, 16, 25, 36, 49, 64], [4, 9, 16, 25, 36, 49, 64, 81]],
  [[9, 16], [9, 16, 25], [9, 16, 25, 36], [9, 16, 25, 36, 49], [9, 16, 25, 36, 49, 64], [9, 16, 25, 36, 49, 64, 81]],
  [[16, 25], [16, 25, 36], [16, 25, 36, 49], [16, 25, 36, 49, 64], [16, 25, 36, 49, 64, 81]],
  [[25, 36], [25, 36, 49], [25, 36, 49, 64], [25, 36, 49, 64, 81]],
  [[36, 49], [36, 49, 64], [36, 49, 64, 81]],
  [[49, 64], [49, 64, 81]],
  [[64, 81]],
]

If we map sum over those, we can use itertools.takewhile to end each row at the first sum that exceeds n, which cuts down on the palindrome checks we need to do later:

sums = [itertools.takewhile(lambda s: s <= n, map(sum, runs)) for runs in seqs]

Because the sums in each row only grow, stopping at the first one above n loses nothing, and the rows get much shorter. Materialized as lists, they hold the accumulated sums:

[ [5, 14, 30, 55, 91],
  [13, 29, 54, 90],
  [25, 50, 86],
  [41, 77],
  [61],
  [85],
  [],
  [],
]

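Then chain the rows together with itertools.chain.from_iterable and pass each sum into is_palindrome. The filter(None, sums) call below is harmless but does not actually drop the empty rows, because takewhile objects are always truthy; chaining an empty iterator simply contributes nothing.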

def is_palindrome(number):
    s = str(number)
    return s == s[::-1]

result = [k for k in itertools.chain.from_iterable(filter(None, sums)) if is_palindrome(k)]
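A quick check of why filter(None, sums) keeps the empty rows: the elements of sums are lazy takewhile iterators, not lists, and Python treats such an object as truthy whether or not it will produce any values. The variable names here are only for the demonstration:

```python
import itertools

# A takewhile object is truthy even when it will yield nothing
empty = itertools.takewhile(lambda s: s <= 100, [113, 145])
is_truthy = bool(empty)   # True: no __bool__/__len__, so the default applies
contents = list(empty)    # [] because 113 > 100 ends the iteration immediately

# filter(None, ...) therefore keeps both rows, including the empty one
streams = [itertools.takewhile(lambda s: s <= 100, row) for row in ([5, 14, 140], [113])]
kept = list(filter(None, streams))

print(is_truthy, contents, len(kept))
```

In practice this is harmless, since chaining an empty iterator adds nothing, so the filter step can simply be dropped.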

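We could do our perfect-square check here too, but we already know that every perfect square up to n is in squares. For arbitrarily large n, it becomes cheaper and cheaper to build these both into sets and use set.difference. A set also removes duplicates, which matters because some numbers are sums of consecutive squares in more than one way (for example 365 = 10²+11²+12² = 13²+14²).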

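result = {k for k in itertools.chain.from_iterable(filter(None, sums)) if is_palindrome(k)}
# same as before, just with curly braces to make it a set comprehension;
# sums holds single-use iterators, so rebuild it if the list version already consumed it
squareset = set(squares)
final = result.difference(squareset)
# equivalent to `result - squareset`; len(final) is the answer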
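For reference, the steps above can be assembled into one self-contained function. This is a sketch rather than the answerer's exact code: runs_starting_at is a hypothetical helper used to bind i eagerly, and the slice ranges are widened by one so the largest square is included (otherwise small inputs such as n = 5 miss 5 = 1+4):

```python
import itertools


def is_palindrome(number):
    s = str(number)
    return s == s[::-1]


def values(n):
    limit = int(n ** 0.5)
    squares = [i ** 2 for i in range(1, limit + 1)]
    num_squares = len(squares)

    def runs_starting_at(i):
        # Slices of two or more consecutive squares beginning at squares[i]
        return (squares[i:j] for j in range(i + 2, num_squares + 1))

    # One lazy row per starting square, cut off at the first sum above n
    sums = [itertools.takewhile(lambda s: s <= n, map(sum, runs_starting_at(i)))
            for i in range(num_squares - 1)]

    # Set comprehension deduplicates; set difference drops perfect squares
    palindromes = {k for k in itertools.chain.from_iterable(sums) if is_palindrome(k)}
    return len(palindromes - set(squares))
```

values(100) gives 3, matching the example in the question.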

A lot of those sites use problems that are relatively easy to solve programmatically but hard to solve efficiently. This will be a common problem you run into.

As for a solution, first try to come up with an efficient algorithm. Then, if it still doesn't satisfy the time constraint, look for less computationally expensive Python standard library operations that achieve the same thing. For example, ask whether pals.extend() traverses the entire list every time. (It doesn't: appending to a list is amortized constant time per element. The test k in pals in your code, however, does scan the whole list on every call, which is why storing pals in a set helps.)
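One concrete way to act on that advice is to measure a suspect operation in isolation with the standard timeit module. A minimal sketch comparing membership tests, the operation behind k in pals in the question's code. The helper name and sizes are arbitrary choices for illustration, and absolute timings will vary by machine:

```python
import timeit


def membership_seconds(container, probe, repeat=200):
    """Time `probe in container`, repeated `repeat` times (hypothetical helper)."""
    return timeit.timeit(lambda: probe in container, number=repeat)


data = list(range(10_000))
as_list = data
as_set = set(data)
missing = -1  # worst case for a list: every element gets compared

list_time = membership_seconds(as_list, missing)
set_time = membership_seconds(as_set, missing)
print(f"list: {list_time:.6f}s  set: {set_time:.6f}s")
```

The list lookup grows with the list's length while the set lookup stays roughly constant, which is the difference that matters once pals holds many values.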
