
How can I make this Scala function (a “flatMap” variant) tail recursive?

I'm having a look at the following code

http://aperiodic.net/phil/scala/s-99/p26.scala

Specifically

def flatMapSublists[A,B](ls: List[A])(f: (List[A]) => List[B]): List[B] = 
    ls match {
      case Nil => Nil
      case sublist@(_ :: tail) => f(sublist) ::: flatMapSublists(tail)(f)
    }

I'm getting a StackOverflowError for large lists, presumably because the function is not tail recursive. Is there a way to transform the function so that it can handle large inputs?

It is definitely not tail recursive. The f(sublist) ::: prepends to the result of the recursive call, so that call is not the last operation the function performs. This makes it a plain-old-stack-blowing recursion instead of a tail recursion.

One way to ensure that your functions are tail recursive is to put the @annotation.tailrec annotation on any function you expect to be tail recursive. The compiler will then report an error if it cannot perform the tail call optimization.
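A small sketch of both points, using a hypothetical list sum (the names TailRecDemo, sumNaive and sum are illustrative, not from the question). Adding @annotation.tailrec to sumNaive would make compilation fail, because the addition happens after the recursive call returns. The inner loop compiles with the annotation because the running total travels in an accumulator and the recursive call is the last thing evaluated.

```scala
object TailRecDemo {
  // Not tail recursive: `head + ...` runs after sumNaive(tail) returns,
  // so each element costs one stack frame.
  def sumNaive(xs: List[Int]): Int = xs match {
    case Nil          => 0
    case head :: tail => head + sumNaive(tail)
  }

  // Tail recursive: the partial sum is carried in `acc`, so the recursive
  // call is in tail position and @tailrec accepts it.
  def sum(xs: List[Int]): Int = {
    @annotation.tailrec
    def loop(rest: List[Int], acc: Int): Int = rest match {
      case Nil          => acc
      case head :: tail => loop(tail, acc + head)
    }
    loop(xs, 0)
  }
}
```

With this version, TailRecDemo.sum(List.fill(1000000)(1)) runs in constant stack space.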

For this, I would add a small helper function that's actually tail recursive:

def flatMapSublistsTR[A,B](ls: List[A])(f: (List[A]) => List[B]): List[B] = {
  @annotation.tailrec
  def helper(r: List[B], ls: List[A]): List[B] = {
    ls match {
      case Nil => r
      case sublist@(_ :: tail) => helper(r ::: f(sublist), tail)
    }
  }
  helper(Nil, ls)
}

(Edit: an earlier version of this helper produced the results in a different order than the original function; that is now fixed, and it appears to work :-)
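Order is the main pitfall when switching to an accumulator: prepending each chunk (f(sublist) ::: r) would reverse the order of the chunks, while appending (r ::: f(sublist)) keeps it. A common alternative that keeps the original order is to prepend each chunk reversed and reverse the accumulator once at the end. The following is a sketch of that technique, not part of the answer above; the object name FlatMapSublistsLinear is hypothetical. reverse_::: is a standard method on List that prepends a list's elements in reverse order, so each step only walks the new chunk.

```scala
object FlatMapSublistsLinear {
  def flatMapSublists[A, B](ls: List[A])(f: List[A] => List[B]): List[B] = {
    // acc holds the output produced so far, in reverse order.
    @annotation.tailrec
    def helper(acc: List[B], rest: List[A]): List[B] = rest match {
      case Nil => acc.reverse // restore the original order once, at the end
      // `f(rest) reverse_::: acc` is acc.reverse_:::(f(rest)), i.e.
      // f(rest).reverse followed by acc, built in time O(|f(rest)|).
      case _ :: tail => helper(f(rest) reverse_::: acc, tail)
    }
    helper(Nil, ls)
  }
}
```

For example, FlatMapSublistsLinear.flatMapSublists(List(1, 2, 3))(xs => xs) gives the same List(1, 2, 3, 2, 3, 3) as the original recursive version.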

Here is another way to implement the function:

scala> def flatMapSublists[A,B](ls: List[A])(f: (List[A]) => List[B]): List[B] =
     |   List.iterate(ls, ls.size)(_.tail).flatMap(f)
flatMapSublists: [A, B](ls: List[A])(f: List[A] => List[B])List[B]
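The same idea, enumerating the suffixes first and then flat-mapping over them, can also be written with List's tails method, which yields ls, ls.tail, ..., Nil. This is a sketch of that variant (the object name is hypothetical). Filtering out the final empty suffix leaves exactly the sublists the recursive version visits, and no ls.size call is needed.

```scala
object FlatMapSublistsTails {
  // tails yields every suffix of ls, ending with Nil; the recursive version
  // only calls f on non-empty suffixes, so Nil is filtered out.
  def flatMapSublists[A, B](ls: List[A])(f: List[A] => List[B]): List[B] =
    ls.tails.filter(_.nonEmpty).flatMap(f).toList
}
```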

A simple comparison between dave's flatMapSublistsTR and mine (times in milliseconds):

scala> def time(count: Int)(call : => Unit):Long = {
     |    val start = System.currentTimeMillis
     |    var cnt =  count
     |    while(cnt > 0) {
     |       cnt -= 1
     |       call
     |    }
     |    System.currentTimeMillis - start
     | }
time: (count: Int)(call: => Unit)Long

scala> val xs = List.range(0,100)

scala> val fn = identity[List[Int]] _
fn: List[Int] => List[Int] = <function1>

scala> time(10000){ flatMapSublists(xs)(fn) }
res1: Long = 5732

scala> time(10000){ flatMapSublistsTR(xs)(fn) }
res2: Long = 347232
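The gap is plausibly explained by r ::: f(sublist) in the tail-recursive helper: ::: copies its left operand, so every step copies everything accumulated so far, and the total work grows roughly quadratically with the size of the output. A sketch that keeps the tail-recursive loop but avoids that copying appends to a scala.collection.mutable.ListBuffer instead (the object name FlatMapSublistsBuffer is hypothetical, and these numbers were not re-measured with it):

```scala
import scala.collection.mutable.ListBuffer

object FlatMapSublistsBuffer {
  def flatMapSublists[A, B](ls: List[A])(f: List[A] => List[B]): List[B] = {
    val out = ListBuffer.empty[B]
    @annotation.tailrec
    def loop(rest: List[A]): Unit = rest match {
      case Nil => ()
      case _ :: tail =>
        out ++= f(rest) // appending costs O(|f(rest)|), not O(|out|)
        loop(tail)
    }
    loop(ls)
    out.toList // results are already in the original order
  }
}
```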

Here flatMapSublistsTR is dave's implementation shown above, unchanged.

def flatMapSublists2[A,B](ls: List[A], result: List[B] = Nil)(f: (List[A]) => List[B]): List[B] = 
    ls match {
      case Nil => result
      case sublist@(_ :: tail) => flatMapSublists2(tail, result ++ f(sublist))(f)
    }

You generally just need to add a result parameter that carries the accumulated value from one iteration to the next, and return it at the end instead of combining each step with the result of the recursive call.

Also, the somewhat confusing sublist@ binding can be dropped, since in that case sublist is just ls:

case _ :: tail => flatMapSublists2(tail, result ++ f(ls))(f)
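Put together, the accumulator version with the simplified pattern looks like the sketch below, wrapped in a hypothetical object so it compiles on its own. The @annotation.tailrec annotation is an addition here; it makes the compiler confirm that the recursive call is in tail position. The test passes the type arguments explicitly rather than relying on inference around the default argument.

```scala
object FlatMapSublistsAcc {
  @annotation.tailrec
  def flatMapSublists2[A, B](ls: List[A], result: List[B] = Nil)(f: List[A] => List[B]): List[B] =
    ls match {
      case Nil       => result // all suffixes processed: return the accumulator
      case _ :: tail => flatMapSublists2(tail, result ++ f(ls))(f) // tail call
    }
}
```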

Off-topic: here's how I solved problem 26, without the need for helper methods like the one above. If you can make this tail-recursive, have a gold star.

  def combinations[A](n: Int, lst: List[A]): List[List[A]] = n match {
    case 1 => lst.map(List(_))
    case _ => lst.flatMap(i => combinations (n - 1, lst.dropWhile(_ != i).tail) map (i :: _))
  }
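One caveat with this version: dropWhile(_ != i) finds the first element equal to i, so when lst contains duplicates the same suffix is reused. For List('a', 'a', 'b') and n = 2 it returns four pairs instead of the three position-based combinations. A sketch that walks the suffixes directly with tails avoids the lookup (the object name is hypothetical). It does not earn the gold star, since it is still not tail recursive, but its recursion depth is only n, not the length of the list.

```scala
object Combinations {
  def combinations[A](n: Int, lst: List[A]): List[List[A]] = n match {
    case 1 => lst.map(List(_))
    case _ =>
      // Each non-empty suffix supplies a head i and exactly the elements
      // after that position, so duplicates are handled by position.
      lst.tails.toList.flatMap {
        case i :: rest => combinations(n - 1, rest).map(i :: _)
        case Nil       => Nil
      }
  }
}
```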

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.
