Strange recursion optimization by Java

I came across some weird results when I was trying to answer this question: How to improve the performance of the recursive method?

You don't need to read that post; the relevant context is given here. It may look lengthy, but it isn't complicated once you read through it. For context,

syra(n) = { 1 if n=1; 
            n + syra(n/2) if n is even; and 
            n + syra(3n+1) if n is odd
          }

and

syralen(n) = number of steps needed to calculate syra(n)

E.g., syralen(1)=1 and syralen(2)=2, since we need two steps. syra(10) = 10 + syra(5) = 15 + syra(16) = 31 + syra(8) = 39 + syra(4) = 43 + syra(2) = 45 + syra(1) = 46. So syra(10) needed 7 steps, and therefore syralen(10)=7.

And finally,

lengths(n) = syralen(1)+syralen(2)+...+syralen(n)

The question poster there is trying to calculate lengths(n)
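The definitions above can be transcribed almost literally into Java. This is only a minimal sketch for checking the worked example: the class name SyraDefinitions and its method names are made up for illustration, and long is used for the argument on the assumption that intermediate values can grow well beyond the starting value.

```java
// Hypothetical helper class; a direct transcription of the definitions above.
class SyraDefinitions {

    // syra(n): sum of every value visited on the way from n down to 1.
    static long syra(long n) {
        if (n < 1) throw new IllegalArgumentException("n must be >= 1");
        if (n == 1) return 1;
        return n + ((n % 2 == 0) ? syra(n / 2) : syra(3 * n + 1));
    }

    // syralen(n): number of values visited, counting both n and the final 1.
    static int syralen(long n) {
        if (n < 1) throw new IllegalArgumentException("n must be >= 1");
        if (n == 1) return 1;
        return 1 + ((n % 2 == 0) ? syralen(n / 2) : syralen(3 * n + 1));
    }

    // lengths(n): syralen(1) + syralen(2) + ... + syralen(n).
    static long lengths(int n) {
        long total = 0;
        for (int i = 1; i <= n; i++) {
            total += syralen(i);
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(syra(10));    // 46, as in the worked example
        System.out.println(syralen(10)); // 7, as in the worked example
        System.out.println(lengths(10));
    }
}
```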

My question is about the recursive solution the OP posted (the second snippet in that question). I'll repost it here:

public class SyraLengths {

    // Running step count for the current chain; reset to 1 whenever a chain ends.
    int total = 1;

    public int syraLength(long n) {
        if (n < 1)
            throw new IllegalArgumentException();
        if (n == 1) {
            // End of the chain: hand back the count and reset it for the next call.
            int temp = total;
            total = 1;
            return temp;
        }
        else if (n % 2 == 0) {
            total++;
            return syraLength(n / 2);
        }
        else {
            total++;
            return syraLength(n * 3 + 1);
        }
    }

    public int lengths(int n) {
        if (n < 1) {
            throw new IllegalArgumentException();
        }
        int total = 0; // local sum; shadows the field of the same name
        for (int i = 1; i <= n; i++) {
            total += syraLength(i);
        }

        return total;
    }

    public static void main(String[] args) {
        System.out.println(new SyraLengths().lengths(5000000));
    }
}

This is certainly an unusual (and probably not recommended) way of doing it recursively, but it does calculate the right thing; I have verified that. I tried to write a more conventional recursive version of the same:

public class SyraSlow {

    public long lengths(int n) {
        long total = 0;
        for (int i = 1; i <= n; ++i) {
            total += syraLen(i);
        }
        return total;
    }

    private long syraLen(int i) {
        if (i == 1)
            return 1;
        return 1 + ((i % 2 == 0) ? syraLen(i / 2) : syraLen(i * 3 + 1));
    }
}

Now here's the weird part. I tried to test the performance of both of the above versions like this:

public static void main(String[] args) {
    long t1 = 0, t2 = 0;
    int TEST_VAL = 50000;

    t1 = System.currentTimeMillis();
    System.out.println(new SyraLengths().lengths(TEST_VAL));
    t2 = System.currentTimeMillis();
    System.out.println("SyraLengths time taken: " + (t2 - t1));

    t1 = System.currentTimeMillis();
    System.out.println(new SyraSlow().lengths(TEST_VAL));
    t2 = System.currentTimeMillis();
    System.out.println("SyraSlow time taken: " + (t2 - t1));
}

For TEST_VAL=50000, the output (times in milliseconds) is:

5075114
SyraLengths time taken: 44
5075114
SyraSlow time taken: 31

As expected (I guess), the plain recursion is slightly faster. But when I go one step further and use TEST_VAL=500000, the output is:

62634795
SyraLengths time taken: 378
Exception in thread "main" java.lang.StackOverflowError
    at SyraSlow.syraLen(SyraSlow.java:15)
    at SyraSlow.syraLen(SyraSlow.java:15)
    at SyraSlow.syraLen(SyraSlow.java:15)

Why is that? What kind of optimization is Java doing here so that the SyraLengths version doesn't hit a StackOverflowError (it works even at TEST_VAL=5000000)? I even tried an accumulator-based recursive version, just in case my JVM was doing some tail-call optimization:

private long syraLenAcc(int i, long acc) {
    if (i == 1) return acc;
    if (i % 2 == 0) {
        return syraLenAcc(i / 2, acc + 1);
    }
    return syraLenAcc(i * 3 + 1, acc + 1);
}

But I still got the same result (so there is no tail-call optimization here). So what is happening?

PS: Please edit to a better title if you can think of any.

With the original version, tail-recursion optimization was possible (within the JIT), though I don't know whether it actually occurred. It's also possible that the original is just slightly more efficient with heap [er, I mean stack] use. (Or there may be a functional difference that was not obvious on cursory examination.)

Well, it turns out there is a simple explanation:

I declared the method as long syraLen(int n). But the values passed along the chain can grow much larger than Integer.MAX_VALUE: for some starting values, n * 3 + 1 no longer fits in an int, so the arithmetic overflows and syraLen receives negative inputs. With a negative input the recursion evidently never reaches the n == 1 base case, hence the StackOverflowError. If I change the signature to long syraLen(long n), everything works perfectly! I wish I had also included the if (n < 1) throw new IllegalArgumentException(); check like the original poster. It would have saved me some time.
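To make the explanation concrete, here is a minimal sketch of the corrected version together with a small check for when int is too narrow. The class name SyraOverflowDemo and the helper firstStartExceedingInt are hypothetical names introduced only for this illustration; the recurrence and the n < 1 guard are the ones discussed above.

```java
// Hypothetical demo class; not part of the code in the question.
class SyraOverflowDemo {

    // Same recurrence as SyraSlow.syraLen, but with a long parameter and the
    // n < 1 guard from the original poster's version.
    static long syraLen(long n) {
        if (n < 1) throw new IllegalArgumentException("n must be >= 1, got " + n);
        if (n == 1) return 1;
        return 1 + ((n % 2 == 0) ? syraLen(n / 2) : syraLen(n * 3 + 1));
    }

    static long lengths(int n) {
        long total = 0;
        for (int i = 1; i <= n; ++i) {
            total += syraLen(i);
        }
        return total;
    }

    // Returns the smallest start value in [1, limit] whose chain contains a
    // value larger than Integer.MAX_VALUE, or -1 if there is no such start.
    // The chain is tracked in a long so the check itself cannot wrap around.
    static int firstStartExceedingInt(int limit) {
        for (int start = 1; start <= limit; ++start) {
            long v = start;
            while (v != 1) {
                v = (v % 2 == 0) ? v / 2 : v * 3 + 1;
                if (v > Integer.MAX_VALUE) return start;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        // -1 here means every chain up to the limit fits in an int.
        System.out.println(firstStartExceedingInt(50000));
        // A positive result is a start value that breaks the int version.
        System.out.println(firstStartExceedingInt(500000));
        System.out.println(lengths(50000));  // the question reports 5075114
        System.out.println(lengths(500000)); // the question reports 62634795
    }
}
```

With the guard in place, an overflowing int version would have failed immediately with an IllegalArgumentException on the first negative argument instead of recursing until the stack ran out.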
