Recursion in Lua vs. Java

I have a recursive algorithm (in three different forms) that computes the sum of the first n odd positive integers. These algorithms are for learning only; I'm aware that there are better ways of solving this problem. I've written the three variants in both Lua and Java, which I'll call Lua 1, Lua 2 and Lua 3, and Java 1, Java 2 and Java 3. The first two are very similar, just rearranged.

The Lua programs are here, and the Java programs are here.

Lua 1 and 2 perform extremely well and can easily reach n = 100,000,000. Lua 3 hits a stack overflow when n > 16,000.

Java 1 could only reach n = 4,000 before hitting a stack overflow, while Java 2 reached 9,000. Java 3 managed to get to n = 15,000 before it, too, hit a stack overflow.

Can anyone explain these results? Why did Java 1, 2 and 3 perform so poorly while Lua 1 and 2 performed so well?

Lua does tail-call elimination (it implements proper tail calls). For example, a function like this:

function foo (n)
    if n > 0 then return foo(n - 1) end
end

will never cause a stack overflow, no matter what value of n you call it with.

In Lua, only a call of the form return func(args) is a tail call, which is what your first two Lua programs use. But in the third Lua program:

return (sumOdds3(n-1)) + (2*n - 1)

Lua still needs to perform the addition after sumOdds3(n-1) returns, so this is not a proper tail call, and every level of the recursion keeps its stack frame.
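A minimal Java sketch of the same non-tail shape, for comparison. The name `sumOddsNonTail` and the `n == 0` base case are assumptions, since the original programs are not reproduced here. The pending `+ (2 * n - 1)` is what forces every caller's frame to stay on the stack until the innermost call returns:

```java
public class SumOddsNonTail {
    // Hypothetical Java counterpart of the Lua 3 shape:
    //   return (sumOdds3(n-1)) + (2*n - 1)
    // The addition runs only after the recursive call returns, so each
    // caller's frame must survive until the whole chain has unwound.
    public static long sumOddsNonTail(long n) {
        if (n == 0) {
            return 0; // assumed base case: the sum of zero odd numbers
        }
        return sumOddsNonTail(n - 1) + (2 * n - 1);
    }

    public static void main(String[] args) {
        // sumOddsNonTail(3) expands to ((0 + 1) + 3) + 5
        System.out.println(sumOddsNonTail(3)); // prints 9
    }
}
```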

Java is not designed for deep recursion. In particular, the JVM doesn't perform common optimisations like tail-call optimisation, so every recursive call consumes a new stack frame.
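To see that the shape of the call does not help in Java, here is a minimal sketch of a tail-recursive version that carries a running total (`sumOddsTail` and `overflows` are hypothetical helpers, not from the original programs). In Lua this form would be a proper tail call; on HotSpot each call still pushes a frame, so a depth in the range Lua 1 and 2 handle should end in a StackOverflowError on a default-sized thread stack:

```java
public class NoTailCallElimination {
    // Hypothetical accumulator version: the recursive call is the very last
    // thing the method does, i.e. it has the shape of a tail call.
    public static long sumOddsTail(long n, long acc) {
        if (n == 0) {
            return acc;
        }
        return sumOddsTail(n - 1, acc + (2 * n - 1));
    }

    // Returns true if evaluating sumOddsTail(n, 0) ran out of stack.
    public static boolean overflows(long n) {
        try {
            sumOddsTail(n, 0);
            return false;
        } catch (StackOverflowError e) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println(sumOddsTail(1_000, 0)); // prints 1000000
        // The JVM keeps one frame per call even in tail position,
        // so this very deep call is expected to overflow.
        System.out.println(overflows(100_000_000));
    }
}
```

Catching StackOverflowError here only demonstrates the failure; it is not a sensible way to cope with deep recursion in real code.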

In Java you are better off using a loop: it is generally much faster, and often simpler to write.
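A minimal sketch of the plain-loop approach (`sumOddsLoop` is a hypothetical name). It uses constant stack space for any n, and it is essentially what tail-call elimination produces from a tail-recursive version that carries a running total: each tail call becomes an update of the loop variables.

```java
public class SumOddsLoop {
    // Sums the first n odd positive integers with a plain loop.
    // Stack usage is constant no matter how large n is.
    public static long sumOddsLoop(long n) {
        long acc = 0;
        for (long k = 1; k <= n; k++) {
            acc += 2 * k - 1; // the k-th odd positive integer
        }
        return acc;
    }

    public static void main(String[] args) {
        System.out.println(sumOddsLoop(100_000_000)); // prints 10000000000000000
    }
}
```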

If you use iteration with a Deque in place of the call stack, you should find there is almost no limit to the value of n.
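A minimal sketch of that idea, assuming the Deque holds the work each recursive call would have left pending (the exact scheme isn't spelled out, and `sumOddsDeque` is a hypothetical name). The pending terms now live on the heap, so the practical limit becomes heap size rather than thread-stack depth:

```java
import java.util.ArrayDeque;
import java.util.Deque;

public class SumOddsDeque {
    // Simulates the non-tail recursion sumOdds(n) = sumOdds(n - 1) + (2n - 1)
    // with an explicit heap-allocated stack instead of the call stack.
    public static long sumOddsDeque(long n) {
        Deque<Long> pending = new ArrayDeque<>();
        // "Descend": record the term each recursive call would add later.
        for (long k = n; k > 0; k--) {
            pending.push(2 * k - 1);
        }
        // "Return": pop the pending terms and add them, innermost call first.
        long sum = 0;
        while (!pending.isEmpty()) {
            sum += pending.pop();
        }
        return sum;
    }

    public static void main(String[] args) {
        System.out.println(sumOddsDeque(1_000_000)); // prints 1000000000000
    }
}
```

For this particular sum the Deque is overkill, since a plain loop does the same work without boxing; the pattern pays off for recursions that are not simple tail calls.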

When you run very inefficient code, whether or not a particular implementation can optimise away the inefficiency makes a big difference.

One much more efficient way is to skip recursion and iteration entirely and use a closed form:

public class SumOdds {
    // Computes the sum of the first n odd positive integers, 1 + 3 + ... + (2n - 1).
    // That sum is exactly n * n, so no recursion or loop is needed; for any
    // n up to Integer.MAX_VALUE the result still fits in a long.
    public static long sumOdds(long n) {
        return n * n;
    }

    public static void main(String[] args) {
        sumOdds(1); // warm-up call before timing

        long start = System.nanoTime();
        long result = sumOdds(Integer.MAX_VALUE); // only the timing is printed
        long time = System.nanoTime() - start;
        System.out.printf("sumOdds(%,d) took %,d ns%n", Integer.MAX_VALUE, time);
    }
}

prints something like this (timings vary between machines and runs):

sumOdds(2,147,483,647) took 343 ns
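For the quantity the question describes, the first n odd positive integers 1, 3, ..., 2n - 1, the closed form follows from the arithmetic-series formula:

$$\sum_{k=1}^{n} (2k - 1) = 2\sum_{k=1}^{n} k - n = 2 \cdot \frac{n(n+1)}{2} - n = n^2 + n - n = n^2$$

For n = Integer.MAX_VALUE = 2^31 - 1, n^2 is about 4.61 × 10^18, which still fits in a Java long (maximum about 9.22 × 10^18).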
