
Fast functional merge sort

Here's my implementation of merge sort in Scala:

object FuncSort {
  def merge(l: Stream[Int], r: Stream[Int]) : Stream[Int] = {
    (l, r) match {
      case (h #:: t, Empty) => l
      case (Empty, h #:: t) => r
      case (x #:: xs, y #:: ys) => if(x < y ) x #:: merge(xs, r) else y #:: merge(l, ys)
    }
  }

  def sort(xs: Stream[Int]) : Stream[Int] = {
    if(xs.length == 1) xs
    else {
      val m = xs.length / 2
      val (l, r) = xs.splitAt(m)
      merge(sort(l), sort(r))
    }
  }
}

It works correctly and it seems that asymptotically it is fine as well, but it is way slower (approx. 10 times) than the Java implementation from here http://algs4.cs.princeton.edu/22mergesort/Merge.java.html and uses a lot of memory. Is there a faster implementation of merge sort which is functional? Obviously, it's possible to port the Java version line by line, but that's not what I'm looking for.

UPD: I've changed Stream to List and #:: to :: and the sorting routine became faster, only three to four times slower than the Java version. But I don't understand why it doesn't crash with a stack overflow? merge isn't tail-recursive, all arguments are strictly evaluated... how is it possible?

You have raised multiple questions. I'll try to answer them in a logical order:

No stack overflow in the Stream version

You did not really ask this one, but it leads to some interesting observations.

In the Stream version you are using #:: merge(...) inside the merge function. Usually this would be a recursive call and might lead to a stack overflow for big enough input data. But not in this case. The operator #::(a, b) is implemented in class ConsWrapper[A] (there is an implicit conversion) and is a synonym for cons.apply[A](hd: A, tl: ⇒ Stream[A]): Cons[A]. As you can see, the second argument is passed by name, meaning it is evaluated lazily.
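A tiny sketch (names are my own) can make the by-name tail visible: building the cons cell does not evaluate its tail, only forcing the tail does.

```scala
object LazyTailDemo {
  def main(args: Array[String]): Unit = {
    var forced = false
    def rest: Stream[Int] = { forced = true; Stream.empty }

    val s = 1 #:: rest // the tail is a by-name parameter: rest is NOT called yet
    assert(!forced)    // building the cons cell did no work
    assert(s.head == 1)

    s.tail             // only forcing the tail actually runs rest
    assert(forced)
  }
}
```

This is exactly why the recursive merge call inside #:: does not grow the stack: it is stored unevaluated and only runs when the next element is demanded.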

That means merge returns a newly created object of type Cons which will eventually call merge again. In other words: the recursion does not happen on the stack, but on the heap. And usually you have plenty of heap.

Using the heap for recursion is a nice technique to handle very deep recursion. But it is much slower than using the stack. So you traded speed for recursion depth. This is the main reason why using Stream is so slow.

The second reason is that for getting the length of the Stream, Scala has to materialize the whole Stream. But during sorting it would have to materialize each element anyway, so this does not hurt very much.

No stack overflow in the List version

When you change Stream to List, you are indeed using the stack for recursion. Now a stack overflow could happen. But with sorting, you usually have a recursion depth of log(size), usually the logarithm to base 2. So to sort 4 billion input items, you would need about 32 stack frames. With a default stack size of at least 320 kB (on Windows; other systems have larger defaults), this leaves room for a lot of recursion and hence for lots of input data to be sorted.
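The 32-frame figure is easy to check with a small sketch (the depth helper is mine): the recursion depth of a halving divide-and-conquer is the ceiling of log2(n).

```scala
object MergeDepth {
  def main(args: Array[String]): Unit = {
    // Depth of a divide-and-conquer that recurses on the larger half:
    // ceil(log2(n)) levels until the blocks reach size 1
    def depth(n: Long): Int = if (n <= 1) 0 else 1 + depth(n - n / 2)

    assert(depth(1024) == 10)
    assert(depth(4L * 1000 * 1000 * 1000) == 32) // ~4 billion items -> 32 frames
  }
}
```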

Faster functional implementation

It depends :-)

You should use the stack, and not the heap, for recursion. And you should decide your strategy depending on the input data:

  1. For small data blocks, sort them in place with some straightforward algorithm. The algorithmic complexity won't bite you, and you can gain a lot of performance from having all data in cache. Of course, you could still hand-code sorting networks for the given sizes.
  2. If you have numeric input data, you can use radix sort and hand the work to the vector units on your processor or your GPU (more sophisticated algorithms can be found in GPU Gems).
  3. For medium-sized data blocks, you can use a divide-and-conquer strategy to split the data across multiple threads (only if you have multiple cores!).
  4. For huge data blocks, use merge sort and split the data into blocks that fit in memory. If you want, you can distribute these blocks over the network and sort in memory.
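As an illustration of strategy 1, here is a hedged sketch of the usual small-block fallback (names and the cutoff value are my own choices, not from the question): an in-place insertion sort that a hybrid sort would call on slices below some tuned threshold.

```scala
object HybridSortSketch {
  val Cutoff = 16 // assumed threshold; worth tuning by benchmarking

  // In-place insertion sort on a(lo..hi): quadratic, but very fast
  // on tiny, cache-resident blocks
  def insertionSort(a: Array[Int], lo: Int, hi: Int): Unit = {
    var i = lo + 1
    while (i <= hi) {
      val v = a(i)
      var j = i - 1
      while (j >= lo && a(j) > v) { a(j + 1) = a(j); j -= 1 }
      a(j + 1) = v
      i += 1
    }
  }

  def main(args: Array[String]): Unit = {
    val a = Array(4, 2, 9, 1, 7)
    insertionSort(a, 0, a.length - 1)
    assert(a.sameElements(Array(1, 2, 4, 7, 9)))
  }
}
```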

Don't use swap, and use your caches. Use mutable data structures if you can and sort in place. I think that functional style and fast sorting do not work very well together. To make sorting really fast, you will have to use stateful operations (e.g. an in-place merge sort on mutable arrays).
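For contrast, a minimal sketch of such a stateful approach, a merge sort over a mutable Array[Int] with a single auxiliary buffer (my own code, not the referenced Java implementation, though it follows the same classic scheme):

```scala
object ArrayMergeSort {
  // Classic top-down mergesort; mutates a in place, using aux as scratch space
  def sort(a: Array[Int]): Unit = {
    val aux = new Array[Int](a.length)

    // Merge the sorted halves a(lo..mid) and a(mid+1..hi) back into a
    def merge(lo: Int, mid: Int, hi: Int): Unit = {
      Array.copy(a, lo, aux, lo, hi - lo + 1)
      var i = lo
      var j = mid + 1
      var k = lo
      while (k <= hi) {
        if (i > mid)              { a(k) = aux(j); j += 1 }
        else if (j > hi)          { a(k) = aux(i); i += 1 }
        else if (aux(j) < aux(i)) { a(k) = aux(j); j += 1 }
        else                      { a(k) = aux(i); i += 1 }
        k += 1
      }
    }

    def go(lo: Int, hi: Int): Unit =
      if (lo < hi) {
        val mid = lo + (hi - lo) / 2
        go(lo, mid)
        go(mid + 1, hi)
        merge(lo, mid, hi)
      }

    go(0, a.length - 1)
  }

  def main(args: Array[String]): Unit = {
    val a = Array(5, 3, 8, 1, 9, 2)
    sort(a)
    assert(a.sameElements(Array(1, 2, 3, 5, 8, 9)))
  }
}
```

The single shared buffer and the in-place writes are exactly the stateful tricks a pure List-based version cannot use.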

I usually try this in all my programs: use pure functional style as far as possible, but use stateful operations for small parts when feasible (e.g. because it has better performance, or because the code just has to deal with lots of state and becomes much more readable when I use vars instead of vals).

There are a couple of things to note here.

First, you don't properly account for the case of your initial stream to sort being empty. You can fix this by modifying the initial check inside sort to read if (xs.length <= 1) xs.

Second, streams can have uncalculable lengths (e.g. Stream.from(1)), which poses a problem when trying to calculate half of that (potentially infinite) length - you might want to consider adding a check for that using hasDefiniteSize or similar (although used naively this could filter out some otherwise calculable streams).
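A small sketch of that guard (assuming the pre-2.13 Stream API): hasDefiniteSize only inspects cons cells that are already evaluated, so it answers immediately instead of diverging on an infinite stream.

```scala
object DefiniteSizeDemo {
  def main(args: Array[String]): Unit = {
    val naturals = Stream.from(1)
    // Safe on infinite streams: no forcing happens, it just reports false
    assert(!naturals.hasDefiniteSize)
    // A finite prefix is still perfectly usable and measurable
    assert(naturals.take(5).length == 5)
  }
}
```

Note the caveat from above: because the check refuses anything not yet fully evaluated, it may also reject finite streams whose length would in fact be calculable.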

Finally, the fact that this is defined to operate on streams may be what is slowing it down. I tried timing a large number of runs of your stream version of mergesort versus a version written to process lists, and the list version came out approximately 3 times faster (admittedly only on a single pair of runs). This suggests that streams are less efficient to work with in this manner than lists or other sequence types (Vector might be faster still, or you could use arrays as in the referenced Java solution).

That said, I'm not a great expert on timings and efficiencies, so someone else may be able to give a more knowledgeable response.

Your implementation is a top-down merge sort. I find that a bottom-up merge sort is faster, and comparable with List.sorted (for my test cases, randomly sized lists of random numbers).

def bottomUpMergeSort[A](la: List[A])(implicit ord: Ordering[A]): List[A] = {
  val l = la.length

  @scala.annotation.tailrec
  def merge(l: List[A], r: List[A], acc: List[A] = Nil): List[A] = (l, r) match {
    case (Nil, Nil)           => acc
    case (Nil, h :: t)        => merge(Nil, t, h :: acc)
    case (h :: t, Nil)        => merge(t, Nil, h :: acc)
    case (lh :: lt, rh :: rt) =>
      if(ord.lt(lh, rh)) merge(lt, r, lh :: acc)
      else               merge(l, rt, rh :: acc)
  }

  @scala.annotation.tailrec
  def process(la: List[A], h: Int, acc: List[A] = Nil): List[A] = {
    if(la == Nil) acc.reverse
    else {
      val (l1, r1) = la.splitAt(h)
      val (l2, r2) = r1.splitAt(h)

      process(r2, h, merge(l1, l2, acc))
    }
  }

  @scala.annotation.tailrec
  def run(la: List[A], h: Int): List[A] =
    if(h >= l) la
    else       run(process(la, h), h * 2)

  run(la, 1)
}
