
Fast functional merge sort

Here's my implementation of merge sort in Scala:

import scala.collection.immutable.Stream.{#::, Empty} // needed for the #:: and Empty patterns below

object FuncSort {
  def merge(l: Stream[Int], r: Stream[Int]): Stream[Int] = {
    (l, r) match {
      case (h #:: t, Empty) => l
      case (Empty, h #:: t) => r
      case (x #:: xs, y #:: ys) => if (x < y) x #:: merge(xs, r) else y #:: merge(l, ys)
    }
  }

  def sort(xs: Stream[Int]): Stream[Int] = {
    if (xs.length == 1) xs
    else {
      val m = xs.length / 2
      val (l, r) = xs.splitAt(m)
      merge(sort(l), sort(r))
    }
  }
}

It works correctly, and asymptotically it seems fine as well, but it is much slower (approx. 10 times) than the Java implementation from here http://algs4.cs.princeton.edu/22mergesort/Merge.java.html and uses a lot of memory. Is there a faster implementation of merge sort that is functional? Obviously, it's possible to port the Java version line by line, but that's not what I'm looking for.

Update: I've changed Stream to List and #:: to ::, and the sorting routine became faster: it is now only three to four times slower than the Java version. But I don't understand why it doesn't crash with a stack overflow. merge isn't tail-recursive and all arguments are strictly evaluated... how is that possible?

You have raised multiple questions. I'll try to answer them in a logical order:

No stack overflow in the Stream version

You did not really ask this one, but it leads to some interesting observations.

In the Stream version you are using #:: merge(...) inside the merge function. Usually this would be a recursive call and might lead to a stack overflow for big enough input data. But not in this case. The expression a #:: b is implemented in class ConsWrapper[A] (reached through an implicit conversion) and is a synonym for Stream.cons.apply[A](hd: A, tl: ⇒ Stream[A]): Cons[A]. As you can see, the second argument is call-by-name, meaning it is evaluated lazily.

That means merge immediately returns a newly created Cons object, and merge is only called again when the tail of that Cons is forced. In other words, the recursion does not happen on the stack but on the heap, and usually you have plenty of heap.
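To make that laziness visible, here is a small sketch that counts how often merge is actually entered. The object name LazyMergeDemo, the counter and the demo method are illustrative additions, not part of the question, and merge is written with isEmpty/head/tail instead of pattern matching so it needs no extra imports. (Stream is deprecated since Scala 2.13 in favour of LazyList, so expect deprecation warnings there.)

```scala
// Hypothetical demo object: counts calls to a Stream-based merge.
object LazyMergeDemo {
  private var calls = 0

  def merge(l: Stream[Int], r: Stream[Int]): Stream[Int] = {
    calls += 1
    if (l.isEmpty) r
    else if (r.isEmpty) l
    // The right operand of #:: is passed by name, so the recursive
    // merge call is NOT made here; it runs when the tail is forced.
    else if (l.head < r.head) l.head #:: merge(l.tail, r)
    else r.head #:: merge(l, r.tail)
  }

  // Returns (calls after the first merge, calls after forcing everything, result).
  def demo(): (Int, Int, List[Int]) = {
    calls = 0
    val merged = merge(Stream(1, 3, 5), Stream(2, 4, 6))
    val before = calls        // only the outermost call has run so far
    val all = merged.toList   // forces every tail, one merge call at a time
    (before, calls, all)
  }
}
```

Right after the first call the counter is 1; forcing the whole result brings it to 6, with each call returning before the next one starts, so the stack never grows.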

Using the heap for recursion is a nice technique for handling very deep recursion, but it is much slower than using the stack: every step allocates a Cons cell plus a closure for the deferred tail. So you traded speed for recursion depth. This is the main reason why the Stream version is so slow.

The second reason is that, to get the length of the Stream, Scala has to materialize the whole Stream. But sorting has to materialize every element anyway, so this does not hurt very much.

No stack overflow in List version

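When you change Stream to List, you are indeed using the stack for recursion, so now a stack overflow can happen. The recursion in sort itself is harmless: each call halves its input, so its depth is about log(size), the logarithm of base 2. To sort 4 billion input items, sort would need only about 32 nested frames. The non-tail-recursive merge is a different matter, though. In x :: merge(xs, r) the recursive call has to return before the cons cell can be built, so merge uses one stack frame per element it emits, and the final merge can need nearly size frames. A default JVM thread stack (roughly 320k on 32-bit Windows, around 1 MB on common 64-bit platforms) typically holds on the order of thousands to tens of thousands of such frames, so the List version most likely did not crash simply because your test lists were small enough. With large enough inputs it will throw a StackOverflowError.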
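A common remedy, sketched here for ascending Int lists with a hypothetical object name, is to make merge tail-recursive: collect the output in reverse in an accumulator and reverse it once at the end. The sort recursion stays top-down and only about log2(size) deep.

```scala
import scala.annotation.tailrec

// Hypothetical name; a sketch, not the poster's code.
object ListMergeSort {
  // Tail-recursive merge: builds the result reversed in `acc`, then
  // reverses it once, so stack usage does not depend on list length.
  def merge(l: List[Int], r: List[Int]): List[Int] = {
    @tailrec
    def loop(l: List[Int], r: List[Int], acc: List[Int]): List[Int] = (l, r) match {
      case (Nil, _) => acc reverse_::: r // acc.reverse ::: r
      case (_, Nil) => acc reverse_::: l
      case (x :: xs, y :: ys) =>
        if (x <= y) loop(xs, r, x :: acc) // <= takes the left element on ties (stable)
        else loop(l, ys, y :: acc)
    }
    loop(l, r, Nil)
  }

  // Top-down sort: recursion depth is only about log2(n).
  def sort(xs: List[Int]): List[Int] = {
    val n = xs.length
    if (n <= 1) xs
    else {
      val (l, r) = xs.splitAt(n / 2)
      merge(sort(l), sort(r))
    }
  }
}
```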

Faster functional implementation

It depends :-)

You should use the stack, not the heap, for recursion, and you should choose your strategy depending on the input data:

  1. For small data blocks, sort them in place with some straightforward algorithm. The algorithmic complexity won't bite you, and you can gain a lot of performance from having all the data in cache. Of course, you could still hand-code sorting networks for the given sizes.
  2. If you have numeric input data, you can use radix sort and hand the work off to the vector units on your processor or to your GPU (more sophisticated algorithms can be found in GPU Gems).
  3. For medium-sized data blocks, you can use a divide-and-conquer strategy to split the data across multiple threads (only if you have multiple cores!).
  4. For huge data blocks, use merge sort and split the data into blocks that fit in memory. If you want, you can distribute these blocks across the network and sort each of them in memory.
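As a CPU-only illustration of the radix sort mentioned in point 2 (no vector units or GPU involved), here is a least-significant-digit radix sort for Int arrays. It is a sketch with a hypothetical object name: it processes one byte per pass (four passes for 32-bit keys) and flips the sign bit so that negative numbers come out before positive ones.

```scala
// Hypothetical name; LSD radix sort sketch for 32-bit Ints.
object RadixSort {
  def sort(input: Array[Int]): Array[Int] = {
    var src = input.clone()          // the caller's array is left untouched
    var dst = new Array[Int](src.length)
    var shift = 0
    while (shift < 32) {
      // Flipping the sign bit maps signed order onto unsigned byte order.
      def digit(x: Int): Int = ((x ^ Int.MinValue) >>> shift) & 0xFF
      val count = new Array[Int](257)
      for (x <- src) count(digit(x) + 1) += 1         // histogram of this byte
      for (b <- 0 until 256) count(b + 1) += count(b) // prefix sums = bucket starts
      for (x <- src) {                                 // stable scatter
        val b = digit(x)
        dst(count(b)) = x
        count(b) += 1
      }
      val tmp = src; src = dst; dst = tmp
      shift += 8
    }
    src // after an even number of swaps, src holds the sorted data
  }
}
```

Each pass is linear and stable, which is what lets the later (more significant) bytes refine the order produced by the earlier ones.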

Avoid swapping to disk, and make good use of your caches. Use mutable data structures if you can, and sort in place. I think that functional and fast sorting do not go together very well: to make sorting really fast, you will have to use stateful operations (e.g. in-place merge sort on mutable arrays).
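For reference, a sketch of such a stateful sort: a top-down merge sort on a mutable Array[Int] with a single auxiliary buffer, in the spirit of the Java Merge.java linked in the question, plus an insertion-sort cutoff for small blocks as in point 1. The cutoff of 16 and the object name are assumptions for illustration, not tuned values or names from either source.

```scala
// Hypothetical name; array-based merge sort sketch.
object ArrayMergeSort {
  private val Cutoff = 16 // assumed threshold for switching to insertion sort

  def sort(a: Array[Int]): Unit = {
    val aux = new Array[Int](a.length) // one scratch buffer, allocated once
    sort(a, aux, 0, a.length - 1)
  }

  private def sort(a: Array[Int], aux: Array[Int], lo: Int, hi: Int): Unit =
    if (hi - lo < Cutoff) insertionSort(a, lo, hi)
    else {
      val mid = lo + (hi - lo) / 2
      sort(a, aux, lo, mid)
      sort(a, aux, mid + 1, hi)
      if (a(mid) > a(mid + 1)) merge(a, aux, lo, mid, hi) // skip if already ordered
    }

  private def merge(a: Array[Int], aux: Array[Int], lo: Int, mid: Int, hi: Int): Unit = {
    System.arraycopy(a, lo, aux, lo, hi - lo + 1)
    var i = lo
    var j = mid + 1
    var k = lo
    while (k <= hi) {
      if (i > mid) { a(k) = aux(j); j += 1 }
      else if (j > hi) { a(k) = aux(i); i += 1 }
      else if (aux(j) < aux(i)) { a(k) = aux(j); j += 1 } // strict <: stable
      else { a(k) = aux(i); i += 1 }
      k += 1
    }
  }

  private def insertionSort(a: Array[Int], lo: Int, hi: Int): Unit = {
    var i = lo + 1
    while (i <= hi) {
      val x = a(i)
      var j = i - 1
      while (j >= lo && a(j) > x) { a(j + 1) = a(j); j -= 1 }
      a(j + 1) = x
      i += 1
    }
  }
}
```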

I usually follow this approach in all my programs: use a pure functional style as far as possible, but use stateful operations for small parts when feasible (e.g. because they perform better, or because the code has to deal with lots of state and becomes much more readable when I use vars instead of vals).
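A minimal sketch of that mix, assuming an Int list: the function below has a purely functional signature (it never modifies its argument and always returns the same result for the same input) while mutating only a private array internally. java.util.Arrays.sort is used purely for brevity; the object name is hypothetical.

```scala
// Hypothetical name; pure interface, local mutation inside.
object PureFacade {
  def sorted(xs: List[Int]): List[Int] = {
    val a = xs.toArray        // private copy; never escapes this method
    java.util.Arrays.sort(a)  // stateful, in-place sort of the copy
    a.toList                  // hand back an immutable List
  }
}
```

Callers cannot observe the mutation, so from the outside the function is referentially transparent.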

There are a couple of things to note here.

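First, you don't properly account for the case where the initial stream to sort is empty. xs.length is then 0, so splitAt(0) returns two empty streams and sort keeps calling itself on an empty stream until the stack overflows. You can fix this by changing the initial check inside sort to if(xs.length <= 1) xs.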
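With that change applied, the question's object might look like the following sketch. Besides the <= 1 check, merge is rewritten with isEmpty, head and tail (a change of my own, not the poster's), which also makes it total: merging two empty streams returns an empty stream instead of failing with a MatchError. The object name is hypothetical.

```scala
// Hypothetical name; the question's Stream sort with the empty-input fix.
object FuncSortFixed {
  def merge(l: Stream[Int], r: Stream[Int]): Stream[Int] =
    if (l.isEmpty) r          // also covers two empty inputs
    else if (r.isEmpty) l
    else if (l.head < r.head) l.head #:: merge(l.tail, r)
    else r.head #:: merge(l, r.tail)

  def sort(xs: Stream[Int]): Stream[Int] =
    if (xs.length <= 1) xs    // <= 1 instead of == 1: empty input terminates
    else {
      val (l, r) = xs.splitAt(xs.length / 2)
      merge(sort(l), sort(r))
    }
}
```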

Second, streams can have uncomputable lengths (e.g. Stream.from(1)), which is a problem when you try to compute half of that (potentially infinite) length. You might want to add a check for that using hasDefiniteSize or similar (although, used naively, this could filter out some otherwise computable streams).
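One way to guard against infinite streams without relying on hasDefiniteSize, sketched under the assumption that refusing very long inputs is acceptable, is to bound the length with lengthCompare, which inspects at most maxLen + 1 elements. The helper name and the maxLen parameter are hypothetical, and xs.toList.sorted merely stands in for whichever sort you actually use.

```scala
// Hypothetical helper: only sorts streams with at most maxLen elements.
object BoundedSort {
  def sortIfAtMost(xs: Stream[Int], maxLen: Int): Option[List[Int]] =
    // lengthCompare stops after maxLen + 1 elements, so this terminates
    // even for an infinite stream such as Stream.from(1).
    if (xs.lengthCompare(maxLen) > 0) None
    else Some(xs.toList.sorted)
}
```

Unlike a naive hasDefiniteSize check, this accepts a finite stream whose tail simply has not been evaluated yet.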

Finally, the fact that this is defined to operate on streams may be what is slowing it down. I timed a large number of runs of your Stream version of merge sort against a version written to process lists, and the List version came out approximately 3 times faster (admittedly only on a single pair of runs). This suggests that streams are less efficient to work with in this way than lists or other sequence types (Vector might be faster still, as might arrays, as in the Java solution referenced).
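The answer only speculates that Vector might be faster, so the following is an untimed sketch rather than a benchmark result: a top-down merge sort over Vector[Int] whose merge is an iterative loop appending to a builder. The local vars are confined to merge, so the function is still pure from the outside; the object name is hypothetical.

```scala
// Hypothetical name; Vector-based merge sort sketch.
object VectorMergeSort {
  // Iterative merge: no recursion, so no stack growth with input size.
  def merge(l: Vector[Int], r: Vector[Int]): Vector[Int] = {
    val b = Vector.newBuilder[Int]
    var i = 0
    var j = 0
    while (i < l.length && j < r.length) {
      if (l(i) <= r(j)) { b += l(i); i += 1 } // take the left element on ties: stable
      else { b += r(j); j += 1 }
    }
    while (i < l.length) { b += l(i); i += 1 }
    while (j < r.length) { b += r(j); j += 1 }
    b.result()
  }

  def sort(xs: Vector[Int]): Vector[Int] =
    if (xs.length <= 1) xs
    else {
      val (l, r) = xs.splitAt(xs.length / 2)
      merge(sort(l), sort(r))
    }
}
```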

That said, I'm not a great expert on timings and efficiency, so someone else may be able to give a more knowledgeable response.

Your implementation is a top-down merge sort. I find that a bottom-up merge sort is faster, and comparable with List.sorted (for my test cases, randomly sized lists of random numbers).

def bottomUpMergeSort[A](la: List[A])(implicit ord: Ordering[A]): List[A] = {
  val n = la.length

  // Merges two sorted lists onto `acc` in reverse order (tail-recursive).
  @scala.annotation.tailrec
  def merge(l: List[A], r: List[A], acc: List[A] = Nil): List[A] = (l, r) match {
    case (Nil, Nil)           => acc
    case (Nil, h :: t)        => merge(Nil, t, h :: acc)
    case (h :: t, Nil)        => merge(t, Nil, h :: acc)
    case (lh :: lt, rh :: rt) =>
      // lteq takes the left element on ties, which keeps the sort stable
      if(ord.lteq(lh, rh)) merge(lt, r, lh :: acc)
      else                 merge(l, rt, rh :: acc)
  }

  // One pass: merges neighbouring runs of width h, front to back.
  @scala.annotation.tailrec
  def process(la: List[A], h: Int, acc: List[A] = Nil): List[A] = {
    if(la.isEmpty) acc.reverse
    else {
      val (l1, r1) = la.splitAt(h)
      val (l2, r2) = r1.splitAt(h)

      process(r2, h, merge(l1, l2, acc))
    }
  }

  // Doubles the run width until a single run covers the whole list.
  @scala.annotation.tailrec
  def run(la: List[A], h: Int): List[A] =
    if(h >= n) la
    else       run(process(la, h), h * 2)

  run(la, 1)
}
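The same bottom-up idea can also be phrased without explicit run widths: start with one-element runs and keep merging neighbouring pairs of runs until a single run is left. The sketch below (hypothetical names, Int only, not benchmarked against the code above or List.sorted) also uses a tail-recursive merge, so no step needs stack space proportional to the input size.

```scala
import scala.annotation.tailrec

// Hypothetical name; bottom-up merge sort by repeated pairwise merging.
object PairwiseBottomUp {
  @tailrec
  private def merge(l: List[Int], r: List[Int], acc: List[Int]): List[Int] = (l, r) match {
    case (Nil, _) => acc reverse_::: r
    case (_, Nil) => acc reverse_::: l
    case (x :: xs, y :: ys) =>
      if (x <= y) merge(xs, r, x :: acc) else merge(l, ys, y :: acc)
  }

  // One pass: List(a, b, c, d, e) becomes List(merge(a, b), merge(c, d), e).
  @tailrec
  private def mergePairs(runs: List[List[Int]], acc: List[List[Int]]): List[List[Int]] =
    runs match {
      case a :: b :: rest => mergePairs(rest, merge(a, b, Nil) :: acc)
      case leftover       => acc reverse_::: leftover // restore original run order
    }

  def sort(xs: List[Int]): List[Int] = {
    @tailrec
    def loop(runs: List[List[Int]]): List[Int] = runs match {
      case Nil        => Nil
      case one :: Nil => one
      case _          => loop(mergePairs(runs, Nil))
    }
    loop(xs.map(x => List(x))) // start from one-element runs
  }
}
```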
