Fast power function in Scala

I tried to write a fast power function in Scala, but I keep getting a java.lang.StackOverflowError. I think it has something to do with the two slashes I use on the third line, where I recursively call this function with n/2. Can someone explain why this is happening?

def fast_power(x: Double, n: Int): Double = {
  if (n % 2 == 0 && n > 1)
    fast_power(x, n / 2) * fast_power(x, n / 2)
  else if (n % 2 == 1 && n > 1)
    x * fast_power(x, n - 1)
  else if (n == 0) 1
  else 1 / fast_power(x, n)
}

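The n/2 itself is fine. Your code doesn't terminate because there is no case for n = 1: it falls through to the final else branch, which calls fast_power(x, 1) again with the same argument, so the recursion never bottoms out. (Negative n has the same problem, because that branch recurses on n instead of -n.) Moreover, your fast_power has linear runtime: the even branch calls fast_power(x, n/2) twice, so the number of calls doubles each time n is halved and ends up proportional to n.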
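To make the linear-vs-logarithmic difference concrete, here is a small sketch (not from either answer) that counts recursive calls. `twoCalls` keeps the question's structure but adds the missing n == 1 case so it terminates; `oneCall` computes the half power once and reuses it. The object `CallCount` and its `calls` counter are hypothetical, added only for this demonstration, and both functions assume n >= 0.

```scala
object CallCount {
  // Hypothetical global counter, used only to measure how many calls are made.
  var calls = 0

  // The question's shape (two recursive calls on n / 2) plus the missing
  // n == 1 base case, so it terminates. Assumes n >= 0.
  def twoCalls(x: Double, n: Int): Double = {
    calls += 1
    if (n == 0) 1.0
    else if (n == 1) x
    else if (n % 2 == 0) twoCalls(x, n / 2) * twoCalls(x, n / 2)
    else x * twoCalls(x, n - 1)
  }

  // Same recursion, but the half power is computed once and squared.
  def oneCall(x: Double, n: Int): Double = {
    calls += 1
    if (n == 0) 1.0
    else if (n == 1) x
    else if (n % 2 == 0) { val s = oneCall(x, n / 2); s * s }
    else x * oneCall(x, n - 1)
  }

  // Resets the counter, runs f once, and returns how many calls it made.
  def count(f: (Double, Int) => Double, n: Int): Int = {
    calls = 0
    f(1.5, n)
    calls
  }

  def main(args: Array[String]): Unit =
    for (n <- Seq(16, 1024))
      println(s"n = $n: two calls -> ${count(twoCalls, n)}, one call -> ${count(oneCall, n)}")
}
```

For exponents that are powers of two, `twoCalls` makes 2n - 1 calls while `oneCall` makes log2(n) + 1, so this prints 31 vs 5 for n = 16 and 2047 vs 11 for n = 1024.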

If you write it down like this instead:

def fast_power(x:Double, n:Int):Double = {
  if(n < 0) {
    1 / fast_power(x, -n)
  } else if (n == 0) {
    1.0
  } else if (n == 1) {
    x
  } else if (n % 2 == 0) {
    val s = fast_power(x, n / 2)
    s * s
  } else {
    val s = fast_power(x, n / 2)
    x * s * s
  }
}

then it is immediately obvious that the runtime is logarithmic, because n is at least halved in every recursive invocation.
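Stack depth is not a real concern for this version, since with an Int exponent the recursion is only a few dozen frames deep. One edge case remains, though: for n == Int.MinValue, -n overflows back to Int.MinValue, so the n < 0 branch calls itself with the same argument forever. A minimal sketch that sidesteps this, assuming it is acceptable to widen the exponent to Long internally, is a tail-recursive square-and-multiply loop. This is a different formulation from the answer's, and the names `FastPow`, `power` and `loop` are illustrative.

```scala
import scala.annotation.tailrec

object FastPow {
  // Invariant inside loop: acc * base^exp == x^|n|.
  // The exponent is widened to Long so that negating Int.MinValue cannot overflow.
  def power(x: Double, n: Int): Double = {
    @tailrec
    def loop(base: Double, exp: Long, acc: Double): Double =
      if (exp == 0) acc
      else if (exp % 2 == 1) loop(base * base, exp / 2, acc * base) // odd: fold one factor into acc
      else loop(base * base, exp / 2, acc)                          // even: just square the base

    val e = n.toLong
    if (e < 0) 1.0 / loop(x, -e, 1.0) else loop(x, e, 1.0)
  }

  def main(args: Array[String]): Unit = {
    println(power(2.0, 10)) // 1024.0
    println(power(2.0, -2)) // 0.25
    println(power(3.0, 0))  // 1.0
  }
}
```

Here `power(2.0, Int.MinValue)` returns 0.0 (the intermediate power overflows to Infinity) instead of overflowing the stack, and `@tailrec` makes the compiler reject the code if `loop` ever stops being tail-recursive.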

I don't have any strong opinions on if vs. match, so I just sorted all the cases in ascending order.

Prefer the match construct over multiple if/else blocks. It helps you isolate the problem you have (the wrong recursive call) and write more understandable recursive functions. Always put the termination conditions first.

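def fastPower(x: Double, m: Int): Double = m match {
  case 0 => 1.0
  case 1 => x
  case n if n < 0 => 1 / fastPower(x, -n) // x^n = 1 / x^(-n); note -Int.MinValue overflows
  case n if n % 2 == 0 =>
    val half = fastPower(x, n / 2) // compute once: calling it twice makes the runtime linear
    half * half
  case n => x * fastPower(x, n - 1) // odd n > 1: the next call gets an even exponent
}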

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint, please credit the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM