
What is the Scala equivalent of Python's NumPy np.random.choice? (Random weighted selection in Scala)

I am looking for Scala code equivalent to Python's np.random.choice (NumPy imported as np), or for the theory behind it. I have an implementation that uses np.random.choice to select random moves from a probability distribution.

Python's code

Input list: ['pooh', 'rabbit', 'piglet', 'Christopher'] and probabilities: [0.5, 0.1, 0.1, 0.3]

I want to select one of the values from the input list, given the probability associated with each element.

The Scala standard library has no equivalent to np.random.choice, but it shouldn't be too difficult to build your own, depending on which options/features you want to emulate.

Here, for example, is a way to get an infinite Stream of submitted items, with the probability of any one item weighted relative to the others.

def weightedSelect[T](input :(T,Int)*): Stream[T] = {
  val items  :Seq[T]    = input.flatMap{x => Seq.fill(x._2)(x._1)}
  def output :Stream[T] = util.Random.shuffle(items).toStream #::: output
  output
}

With this, each input item is paired with a multiplier. So to get an infinite pseudorandom selection of the characters c and v, with c coming up 3/5 of the time and v coming up 2/5 of the time:

val cvs = weightedSelect(('c',3),('v',2))

Thus the rough equivalent of the np.random.choice(aa_milne_arr,5,p=[0.5,0.1,0.1,0.3]) example would be:

weightedSelect("pooh"-> 5
              ,"rabbit" -> 1
              ,"piglet" -> 1
              ,"Christopher" -> 3).take(5).toArray
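To see how the multiplier version behaves, here is a minimal sketch under a few assumptions: the names toMultipliers, weightedBlocks and countsPerBlock are made up for this illustration, and the helper returns a finite Vector rather than a Stream so the result is easy to inspect. It scales the probabilities to whole-number multipliers and then checks how often each item appears in every run of 10 draws.

```scala
object ShuffleBlocks {
  import scala.util.Random

  // Hypothetical helper: scale probabilities to whole-number multipliers.
  // `scale` has to be chosen so that every p * scale is (close to) an integer.
  def toMultipliers(probs: Seq[Double], scale: Int): Seq[Int] =
    probs.map(p => math.round(p * scale).toInt)

  // Same idea as the Stream version above, but finite: expand each item by
  // its multiplier, then append `rounds` independent shuffles of that pool.
  def weightedBlocks[T](rounds: Int, rng: Random)(input: (T, Int)*): Vector[T] = {
    val pool: Vector[T] = input.toVector.flatMap { case (item, n) => Vector.fill(n)(item) }
    Vector.fill(rounds)(rng.shuffle(pool)).flatten
  }

  // Count how often each item occurs in consecutive blocks of `size` draws.
  def countsPerBlock[T](xs: Seq[T], size: Int): List[Map[T, Int]] =
    xs.grouped(size).map(_.groupBy(identity).map { case (k, v) => k -> v.size }).toList
}

val names = Seq("pooh", "rabbit", "piglet", "Christopher")
val mults = ShuffleBlocks.toMultipliers(Seq(0.5, 0.1, 0.1, 0.3), scale = 10)
val draws = ShuffleBlocks.weightedBlocks(3, new scala.util.Random(1))(names.zip(mults): _*)
val perBlock = ShuffleBlocks.countsPerBlock(draws, 10)
```

Because the pool is reshuffled as a whole, each block of 10 draws holds every item exactly as many times as its multiplier. Only the order inside a block is random, so consecutive draws are not independent of one another.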

Or perhaps you want independent draws (less pseudo-random than reshuffling a fixed pool), which also copes with a heavily lopsided distribution.

def weightedSelect[T](items :Seq[T], distribution :Seq[Double]) :Stream[T] = {
  assert(items.length == distribution.length)
  assert(math.abs(1.0 - distribution.sum) < 0.001) // must be at least close

  val dsums  :Seq[Double] = distribution.scanLeft(0.0)(_+_).tail // cumulative sums
  val distro :Seq[Double] = dsums.init :+ 1.1 // close a possible gap
  Stream.continually {
    val r = util.Random.nextDouble() // one draw per selection, not one per comparison
    items(distro.indexWhere(_ > r))
  }
}

The result is still an infinite Stream of the specified elements but the passed-in arguments are a bit different.

val choices :Stream[String] = weightedSelect( List("this"     , "that")
                                           , Array(4998/5000.0, 2/5000.0))

// let's test the distribution
val (choiceA, choiceB) = choices.take(10000).partition(_ == "this")

choiceA.length  //res0: Int = 9995
choiceB.length  //res1: Int = 5  (not bad)
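If a fixed-size result is more convenient than an infinite Stream, the same cumulative-sum idea can be wrapped in a function shaped roughly like np.random.choice(a, size, p=...) with replacement. This is only a sketch: the name choice, its parameter order and the optional rng argument are assumptions made for this example, not part of any library.

```scala
object Choice {
  import scala.util.Random

  // Draw `size` items with replacement; item i is picked with probability probs(i).
  def choice[T](items: IndexedSeq[T], probs: Seq[Double], size: Int,
                rng: Random = new Random()): Vector[T] = {
    require(items.nonEmpty && items.length == probs.length, "items and probs must align")
    require(probs.forall(_ >= 0.0), "probabilities must be non-negative")
    require(math.abs(probs.sum - 1.0) < 1e-6, "probabilities must sum to 1")

    // Running totals, e.g. [0.5, 0.1, 0.1, 0.3] -> [0.5, 0.6, 0.7, 1.0]
    val cumulative: Vector[Double] = probs.scanLeft(0.0)(_ + _).tail.toVector

    Vector.fill(size) {
      val r = rng.nextDouble()              // a single draw in [0, 1) per sample
      val i = cumulative.indexWhere(_ > r)  // first bucket whose upper bound exceeds r
      items(if (i >= 0) i else items.length - 1) // guard against rounding in the last total
    }
  }
}

val milne = Vector("pooh", "rabbit", "piglet", "Christopher")
val five  = Choice.choice(milne, Seq(0.5, 0.1, 0.1, 0.3), 5, new scala.util.Random(42))
```

Passing a seeded Random makes a run repeatable, which helps when testing. Leaving it out falls back to a fresh generator on each call.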
