
Efficiently randomly sampling List while maintaining order

I would like to take random samples from very large lists while maintaining the order. I wrote the script below, but it requires .map(idx => ls(idx)), which is very wasteful since each ls(idx) lookup is linear in the length of the list. I can see a way of making this more efficient with a helper function and tail recursion, but I feel that there must be a simpler solution that I'm missing.

Is there a clean and more efficient way of doing this?

import scala.util.Random

def sampledList[T](ls: List[T], sampleSize: Int) = {
  Random
    .shuffle(ls.indices.toList)
    .take(sampleSize)
    .sorted
    .map(idx => ls(idx))
}

val sampleList = List("t","h","e"," ","q","u","i","c","k"," ","b","r","o","w","n")
// imagine the list is much longer though

sampledList(sampleList, 5) // List(e, u, i, r, n)

EDIT: It appears I was unclear: I am referring to maintaining the order of the values, not the original List collection.

If by

maintaining the order of the values

you mean keeping the elements in the sample in the same order as in the ls list, then with a small modification to your original solution the performance can be greatly improved:

import scala.util.Random

def sampledList[T](ls: List[T], sampleSize: Int) = {
  Random.shuffle(ls.zipWithIndex).take(sampleSize).sortBy(_._2).map(_._1)
}

This solution has a complexity of O(n + k*log(k)), where n is the list's size and k is the sample size, while your solution is O(n + k*log(k) + n*k); the extra n*k term comes from the k calls to ls(idx), each of which is O(n) on a List.
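As an aside (not part of the original answer): if you wanted to keep the index-based approach, converting to a Vector first also eliminates the n*k term, since Vector's apply is effectively constant time. A minimal sketch, with sampledListViaVector being a hypothetical name:

import scala.util.Random

def sampledListViaVector[T](ls: List[T], sampleSize: Int): List[T] = {
  val v = ls.toVector // one O(n) conversion; Vector.apply is effectively constant time
  Random.shuffle(v.indices.toList).take(sampleSize).sorted.map(v(_))
}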

Here is a (more complex) alternative that has O(n) complexity. You can't get any better in terms of complexity (though you could get better performance by using another collection, in particular a collection that has a constant-time size implementation). I did a quick benchmark which indicated that the speedup is very substantial.

import scala.util.Random
import scala.annotation.tailrec

def sampledList[T](ls: List[T], sampleSize: Int) = {
  @tailrec
  def rec(list: List[T], listSize: Int, sample: List[T], sampleSize: Int): List[T] = {
    require(listSize >= sampleSize,
      s"listSize must be >= sampleSize, but got listSize=$listSize and sampleSize=$sampleSize"
    )
    list match {
      case hd :: tl =>
        // Keep the head with probability sampleSize / listSize
        if (Random.nextInt(listSize) < sampleSize)
          rec(tl, listSize - 1, hd :: sample, sampleSize - 1)
        else rec(tl, listSize - 1, sample, sampleSize)
      case Nil =>
        require(sampleSize == 0, // Should never happen
          s"sampleSize must be zero at the end of processing, but got $sampleSize"
        )
        sample
    }
  }
  // Elements were prepended, so reverse to restore the original order
  rec(ls, ls.size, Nil, sampleSize).reverse
}

The above implementation simply iterates over the list and keeps (or skips) each element according to a probability designed to give every element the same chance of being selected. My logic may have a flaw, but at first blush it seems sound to me.
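If you want to convince yourself empirically that the selection is uniform, a quick sanity check along these lines (a hypothetical test, not from the answer) can be run against the function above:

import scala.util.Random

// Sample 3 of 10 elements many times; a uniform sampler should pick each
// element in roughly runs * 3 / 10 = 30000 of the iterations.
val data = (0 until 10).toList
val runs = 100000
val counts = Array.fill(10)(0)
for (_ <- 1 to runs; x <- sampledList(data, 3)) counts(x) += 1
counts.foreach(println) // each count should land near 30000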

Here's another O(n) implementation that should have a uniform probability for each element:

import scala.util.Random
import scala.collection.mutable.ListBuffer

implicit class SampleSeqOps[T](s: Seq[T]) {
  def sample(n: Int, r: Random = Random): Seq[T] = {
    assert(n >= 0 && n <= s.length)

    val res = ListBuffer[T]()

    val length = s.length
    var samplesNeeded = n

    for { (e, i) <- s.zipWithIndex } {
      // Probability of keeping element i: samples still needed / elements left
      val p = samplesNeeded.toDouble / (length - i)

      if (p >= r.nextDouble()) {
        res += e
        samplesNeeded -= 1
      }
    }

    res.toSeq
  }
}
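A minimal usage sketch (assuming the implicit class is in scope):

val picked = (1 to 100000).sample(10)
// 10 elements, kept in their original ascending order (values vary per run)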

I'm using it frequently with collections of more than 100,000 elements and the performance seems reasonable.

It's probably the same idea as in Régis Jean-Gilles's answer, but I think the imperative solution is slightly more readable in this case.

Perhaps I don't quite understand, but since Lists are immutable you don't really need to worry about 'maintaining the order': the original List is never touched. Wouldn't the following suffice?

import scala.util.Random

def sampledList[T](ls: List[T], sampleSize: Int) =
  Random.shuffle(ls).take(sampleSize)
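A usage sketch (the sample's internal order follows the shuffle, so output varies per run):

sampledList(List("t", "h", "e", "q", "u"), 3)
// e.g. List(u, t, q) on one run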

While my previous answer has linear complexity, it does have the drawback of requiring two passes, the first one corresponding to the need to compute the length before doing anything else. Besides affecting the running time, we might want to sample a very large collection for which it is neither practical nor efficient to load the whole collection into memory at once, in which case we'd like to be able to work with a simple iterator. As it happens, we don't need to invent anything to fix this. There is a simple and clever algorithm called reservoir sampling which does exactly this (building a sample as we iterate over a collection, all in one pass). With a minor modification we can also preserve the order, as required:

import scala.util.Random

def sampledList[T](ls: TraversableOnce[T], sampleSize: Int, preserveOrder: Boolean = false, rng: Random = new Random): Iterable[T] = {
  val result = collection.mutable.Buffer.empty[(T, Int)]
  for ((item, n) <- ls.toIterator.zipWithIndex) {
    // The first sampleSize items fill the reservoir, each tagged with its index
    if (n < sampleSize) result += (item -> n)
    else {
      // Keep the item at 0-based index n with probability sampleSize / (n + 1),
      // evicting a uniformly chosen element of the reservoir
      val s = rng.nextInt(n + 1)
      if (s < sampleSize) {
        result(s) = (item -> n)
      }
    }
  }
  if (preserveOrder) {
    result.sortBy(_._2).map(_._1)
  }
  else result.map(_._1)
}
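A usage sketch showing the one-pass, iterator-friendly nature of the approach (the numbers are illustrative):

// Sample 5 of a million numbers without materializing the whole sequence
val picked = sampledList(Iterator.range(0, 1000000), 5, preserveOrder = true)
// e.g. 5 values, in ascending order of position (varies per run)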
