
Random weighted selection in Java

I want to choose a random item from a set, but the chance of choosing any item should be proportional to the associated weight.

Example inputs:

item                weight
----                ------
sword of misery         10
shield of happy          5
potion of dying          6
triple-edged sword       1

So, if I have 4 possible items, the chance of getting any one item without weights would be 1 in 4.

In this case, a user should be 10 times more likely to get the sword of misery than the triple-edged sword.
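As a quick check of the arithmetic: the weights in the table sum to 22, so each item's probability is its weight divided by 22 (about 45.5%, 22.7%, 27.3% and 4.5%), and 10/22 divided by 1/22 gives the factor of 10. A minimal sketch of that normalization; the class name WeightShares is chosen only for illustration:

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class WeightShares {
    // Converts raw weights into probabilities by dividing each weight by the total.
    static Map<String, Double> shares(Map<String, Double> weights) {
        double total = 0;
        for (double w : weights.values()) {
            total += w;
        }
        Map<String, Double> result = new LinkedHashMap<>();
        for (Map.Entry<String, Double> e : weights.entrySet()) {
            result.put(e.getKey(), e.getValue() / total);
        }
        return result;
    }

    public static void main(String[] args) {
        Map<String, Double> weights = new LinkedHashMap<>();
        weights.put("sword of misery", 10.0);
        weights.put("shield of happy", 5.0);
        weights.put("potion of dying", 6.0);
        weights.put("triple-edged sword", 1.0);
        // Prints each item with its share of the total weight (22).
        shares(weights).forEach((item, p) -> System.out.printf("%-20s %.3f%n", item, p));
    }
}
```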

How do I make a weighted random selection in Java?

I would use a NavigableMap:

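import java.util.NavigableMap;
import java.util.Random;
import java.util.TreeMap;

public class RandomCollection<E> {
    // Keys are running totals of the weights; each key "owns" the interval since the previous key.
    private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
    private final Random random;
    private double total = 0;

    public RandomCollection() {
        this(new Random());
    }

    public RandomCollection(Random random) {
        this.random = random;
    }

    public RandomCollection<E> add(double weight, E result) {
        if (weight <= 0) return this; // non-positive weights could never be drawn, so skip them
        total += weight;
        map.put(total, result);
        return this;
    }

    public E next() {
        if (map.isEmpty()) throw new IllegalStateException("no items with a positive weight");
        // Uniform value in [0, total); the first key strictly above it identifies the item.
        double value = random.nextDouble() * total;
        return map.higherEntry(value).getValue();
    }
}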
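To see why higherEntry picks the right item: with weights 40, 35 and 25 the map holds the cumulative keys 40, 75 and 100, so a value in [0, 40) maps to the first item, [40, 75) to the second and [75, 100) to the third. A small standalone sketch of just that lookup (CumulativeKeysDemo is an illustrative name):

```java
import java.util.NavigableMap;
import java.util.TreeMap;

public class CumulativeKeysDemo {
    // Builds the same kind of map as RandomCollection: key = running total of the weights.
    static NavigableMap<Double, String> build() {
        NavigableMap<Double, String> map = new TreeMap<>();
        double total = 0;
        total += 40;
        map.put(total, "dog");   // covers [0, 40)
        total += 35;
        map.put(total, "cat");   // covers [40, 75)
        total += 25;
        map.put(total, "horse"); // covers [75, 100)
        return map;
    }

    // higherEntry returns the entry with the smallest key strictly greater than value.
    static String lookup(NavigableMap<Double, String> map, double value) {
        return map.higherEntry(value).getValue();
    }

    public static void main(String[] args) {
        NavigableMap<Double, String> map = build();
        for (double v : new double[] {0.0, 39.9, 40.0, 74.9, 75.0, 99.9}) {
            System.out.println(v + " -> " + lookup(map, v));
        }
    }
}
```

A value exactly equal to a key (such as 40.0) belongs to the next item, which is why the intervals are half-open.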

Say I have a list of animals (dog, cat, horse) with probabilities of 40%, 35% and 25% respectively:

RandomCollection<String> rc = new RandomCollection<String>()
                              .add(40, "dog").add(35, "cat").add(25, "horse");

for (int i = 0; i < 10; i++) {
    System.out.println(rc.next());
} 
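Because the printed sequence is random, one way to check the weighting is to tally many draws from a seeded source and compare the fractions with 40/35/25. A minimal sketch, assuming a hand-built cumulative map equivalent to the one RandomCollection builds; FrequencyCheck, the seed and the sample size are illustrative:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Random;
import java.util.TreeMap;
import java.util.function.Supplier;

public class FrequencyCheck {
    // Draws n samples from the sampler and returns the observed fraction of each outcome.
    static Map<String, Double> frequencies(Supplier<String> sampler, int n) {
        Map<String, Integer> counts = new HashMap<>();
        for (int i = 0; i < n; i++) {
            counts.merge(sampler.get(), 1, Integer::sum);
        }
        Map<String, Double> fractions = new HashMap<>();
        counts.forEach((k, c) -> fractions.put(k, c / (double) n));
        return fractions;
    }

    // Same cumulative-map technique, with a seeded Random so runs are repeatable.
    static Supplier<String> animalSampler(long seed) {
        NavigableMap<Double, String> map = new TreeMap<>();
        map.put(40.0, "dog");
        map.put(75.0, "cat");
        map.put(100.0, "horse");
        Random random = new Random(seed);
        return () -> map.higherEntry(random.nextDouble() * 100.0).getValue();
    }

    public static void main(String[] args) {
        // Expect roughly 0.40 / 0.35 / 0.25; exact values depend on the seed.
        System.out.println(frequencies(animalSampler(42L), 100_000));
    }
}
```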

There is now a class for this in Apache Commons Math: EnumeratedDistribution.

Item selectedItem = new EnumeratedDistribution<>(itemWeights).sample();

where itemWeights is a List<Pair<Item, Double>>, built for example like this (assuming the Item interface from Arne's answer):

final List<Pair<Item, Double>> itemWeights = new ArrayList<>();
for (Item i : itemSet) {
    itemWeights.add(new Pair<Item, Double>(i, i.getWeight()));
}

Or, in Java 8:

List<Pair<Item, Double>> itemWeights = itemSet.stream()
        .map(i -> new Pair<Item, Double>(i, i.getWeight()))
        .collect(toList());

Note: Pair here needs to be org.apache.commons.math3.util.Pair, not org.apache.commons.lang3.tuple.Pair.

You will not find a framework for this kind of problem, as the requested functionality is nothing more than a simple function. Do something like this:

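import java.util.List;

interface Item {
    double getWeight();
}

class RandomItemChooser {
    public Item chooseOnWeight(List<Item> items) {
        // Sum of all weights.
        double completeWeight = 0.0;
        for (Item item : items)
            completeWeight += item.getWeight();
        // Uniform point in [0, completeWeight).
        double r = Math.random() * completeWeight;
        // Walk the running total until it passes r; the strict '>' means zero-weight items are never chosen.
        double countWeight = 0.0;
        for (Item item : items) {
            countWeight += item.getWeight();
            if (countWeight > r)
                return item;
        }
        // Only reached when the list is empty or all weights are zero.
        throw new IllegalArgumentException("items must have a positive total weight");
    }
}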
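A usage sketch of the same running-total scan applied to the question's table. It assumes a hypothetical Weapon class in place of the Item interface and passes in a Random (an assumption made here so results can be reproduced; the answer itself uses Math.random()):

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

public class WeaponDrop {
    // Hypothetical item type: a name plus a weight.
    static final class Weapon {
        final String name;
        final double weight;

        Weapon(String name, double weight) {
            this.name = name;
            this.weight = weight;
        }
    }

    // Running-total scan with an injectable Random.
    static Weapon choose(List<Weapon> items, Random random) {
        double total = 0.0;
        for (Weapon w : items) {
            total += w.weight;
        }
        double r = random.nextDouble() * total; // uniform in [0, total)
        double running = 0.0;
        for (Weapon w : items) {
            running += w.weight;
            if (running > r) {
                return w;
            }
        }
        throw new IllegalArgumentException("items must have a positive total weight");
    }

    static List<Weapon> questionTable() {
        return Arrays.asList(
                new Weapon("sword of misery", 10),
                new Weapon("shield of happy", 5),
                new Weapon("potion of dying", 6),
                new Weapon("triple-edged sword", 1));
    }

    // Counts how often each weapon name is drawn in the given number of trials.
    static Map<String, Integer> tally(List<Weapon> items, Random random, int draws) {
        Map<String, Integer> counts = new HashMap<>();
        for (int i = 0; i < draws; i++) {
            counts.merge(choose(items, random).name, 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        // The sword of misery should come up roughly 10 times as often as the triple-edged sword.
        System.out.println(tally(questionTable(), new Random(1), 22_000));
    }
}
```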

Use an alias method

If you're going to roll a lot of times (as in a game), you should use an alias method.

The code below is indeed a rather long implementation of such an alias method, but that is because of the initialization part. The retrieval of elements is very fast (see the next and applyAsInt methods: they don't loop).
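To illustrate why retrieval needs no loop, here is the lookup step on its own, with alias tables computed by hand for the question's weights 10, 5, 6 and 1 (total 22). The class name and the hard-coded tables are assumptions made for this sketch; impliedProbabilities recomputes each item's exact probability from the two arrays so the tables can be checked:

```java
import java.util.Random;

public class AliasLookupDemo {
    static final String[] ITEMS = {"sword of misery", "shield of happy", "potion of dying", "triple-edged sword"};
    static final double[] WEIGHTS = {10, 5, 6, 1};
    // Hand-computed tables for the weights above: PROB[i] is the chance of keeping
    // column i once it is picked; otherwise ALIAS[i] is returned instead.
    static final double[] PROB = {1.0, 10.0 / 11, 3.0 / 11, 2.0 / 11};
    static final int[] ALIAS = {0, 0, 0, 2};

    // O(1) draw: one uniform column pick plus one biased coin flip, no loop.
    static int draw(Random random) {
        int column = random.nextInt(PROB.length);
        return random.nextDouble() < PROB[column] ? column : ALIAS[column];
    }

    // Exact probability of each index implied by the tables.
    static double[] impliedProbabilities() {
        int n = PROB.length;
        double[] p = new double[n];
        for (int column = 0; column < n; column++) {
            p[column] += PROB[column] / n;
            p[ALIAS[column]] += (1.0 - PROB[column]) / n;
        }
        return p;
    }

    public static void main(String[] args) {
        double[] p = impliedProbabilities();
        for (int i = 0; i < ITEMS.length; i++) {
            System.out.printf("%-20s %.4f (weight/22 = %.4f)%n", ITEMS[i], p[i], WEIGHTS[i] / 22);
        }
        System.out.println("drop: " + ITEMS[draw(new Random())]);
    }
}
```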

Usage

Set<Item> items = ... ;
ToDoubleFunction<Item> weighter = ... ;

Random random = new Random();

RandomSelector<Item> selector = RandomSelector.weighted(items, weighter);
Item drop = selector.next(random);

Implementation

This implementation:

  • uses Java 8;
  • is designed to be as fast as possible (well, at least, I tried to do so using micro-benchmarking);
  • is totally thread-safe (keep one Random in each thread for maximum performance, e.g. via ThreadLocalRandom);
  • fetches elements in O(1), unlike what you mostly find on the internet or on StackOverflow, where naive implementations run in O(n) or O(log(n));
  • keeps the items independent from their weights, so an item can be assigned various weights in different contexts.

Anyway, here's the code. (Note that I maintain an up-to-date version of this class.)

import static java.util.Objects.requireNonNull;

import java.util.*;
import java.util.function.*;

public final class RandomSelector<T> {

  public static <T> RandomSelector<T> weighted(Set<T> elements, ToDoubleFunction<? super T> weighter)
      throws IllegalArgumentException {
    requireNonNull(elements, "elements must not be null");
    requireNonNull(weighter, "weighter must not be null");
    if (elements.isEmpty()) { throw new IllegalArgumentException("elements must not be empty"); }

    // Array is faster than anything. Use that.
    int size = elements.size();
    @SuppressWarnings("unchecked") // safe: the array is only ever read back as T inside this class
    T[] elementArray = elements.toArray((T[]) new Object[size]);

    double totalWeight = 0d;
    double[] discreteProbabilities = new double[size];

    // Retrieve the probabilities
    for (int i = 0; i < size; i++) {
      double weight = weighter.applyAsDouble(elementArray[i]);
      if (weight < 0.0d) { throw new IllegalArgumentException("weighter may not return a negative number"); }
      discreteProbabilities[i] = weight;
      totalWeight += weight;
    }
    if (totalWeight == 0.0d) { throw new IllegalArgumentException("the total weight of elements must be greater than 0"); }

    // Normalize the probabilities
    for (int i = 0; i < size; i++) {
      discreteProbabilities[i] /= totalWeight;
    }
    return new RandomSelector<>(elementArray, new RandomWeightedSelection(discreteProbabilities));
  }

  private final T[] elements;
  private final ToIntFunction<Random> selection;

  private RandomSelector(T[] elements, ToIntFunction<Random> selection) {
    this.elements = elements;
    this.selection = selection;
  }

  public T next(Random random) {
    return elements[selection.applyAsInt(random)];
  }

  private static class RandomWeightedSelection implements ToIntFunction<Random> {
    // Alias method implementation O(1)
    // using Vose's algorithm to initialize O(n)

    private final double[] probabilities;
    private final int[] alias;

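    RandomWeightedSelection(double[] probabilities) {
      int size = probabilities.length;
      // Assign the final fields; the probabilities array is rewritten in place below.
      this.probabilities = probabilities;
      this.alias = new int[size];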

      double average = 1.0d / size;
      int[] small = new int[size];
      int smallSize = 0;
      int[] large = new int[size];
      int largeSize = 0;

      // Describe a column as either small (below average) or large (above average).
      for (int i = 0; i < size; i++) {
        if (probabilities[i] < average) {
          small[smallSize++] = i;
        } else {
          large[largeSize++] = i;
        }
      }

      // For each column, saturate a small probability to average with a large probability.
      while (largeSize != 0 && smallSize != 0) {
        int less = small[--smallSize];
        int more = large[--largeSize];
        // `more` donates (average - probabilities[less]); use the unscaled value before rescaling.
        probabilities[more] += probabilities[less] - average;
        // Rescale column `less` to [0, 1): the chance of keeping it rather than its alias.
        probabilities[less] = probabilities[less] * size;
        alias[less] = more;
        if (probabilities[more] < average) {
          small[smallSize++] = more;
        } else {
          large[largeSize++] = more;
        }
      }

      // Flush unused columns.
      while (smallSize != 0) {
        probabilities[small[--smallSize]] = 1.0d;
      }
      while (largeSize != 0) {
        probabilities[large[--largeSize]] = 1.0d;
      }
    }

    @Override public int applyAsInt(Random random) {
      // Call random once to decide which column will be used.
      int column = random.nextInt(probabilities.length);

      // Call random a second time to decide which will be used: the column or the alias.
      if (random.nextDouble() < probabilities[column]) {
        return column;
      } else {
        return alias[column];
      }
    }
  }
}
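To sanity-check the table-building step, the sketch below is a separate, compact implementation of Vose's alias method (scaled so the average column height is 1 rather than 1/n, which is equivalent) plus a function that recomputes the exact probability the tables assign to each index. AliasTableCheck and its methods are illustrative names, not part of the RandomSelector code:

```java
import java.util.Arrays;

public class AliasTableCheck {
    // Fills prob[i] (chance to keep column i) and alias[i] (fallback index) using Vose's method.
    static void build(double[] weights, double[] prob, int[] alias) {
        int n = weights.length;
        double total = 0;
        for (double w : weights) total += w;
        int[] small = new int[n];
        int[] large = new int[n];
        int smallSize = 0;
        int largeSize = 0;
        // Scale so that the average column height is exactly 1.
        for (int i = 0; i < n; i++) {
            prob[i] = weights[i] * n / total;
            alias[i] = i;
            if (prob[i] < 1.0) small[smallSize++] = i;
            else large[largeSize++] = i;
        }
        while (smallSize > 0 && largeSize > 0) {
            int less = small[--smallSize];
            int more = large[--largeSize];
            alias[less] = more;
            // `more` gives up (1 - prob[less]) to fill column `less` to height 1.
            prob[more] = prob[more] + prob[less] - 1.0;
            if (prob[more] < 1.0) small[smallSize++] = more;
            else large[largeSize++] = more;
        }
        // Leftover columns are full, up to floating-point rounding.
        while (largeSize > 0) prob[large[--largeSize]] = 1.0;
        while (smallSize > 0) prob[small[--smallSize]] = 1.0;
    }

    // Exact probability of each index implied by the tables.
    static double[] implied(double[] prob, int[] alias) {
        int n = prob.length;
        double[] p = new double[n];
        for (int c = 0; c < n; c++) {
            p[c] += prob[c] / n;
            p[alias[c]] += (1.0 - prob[c]) / n;
        }
        return p;
    }

    public static void main(String[] args) {
        double[] weights = {10, 5, 6, 1};
        double[] prob = new double[weights.length];
        int[] alias = new int[weights.length];
        build(weights, prob, alias);
        // Should match weights[i] / 22 for each index.
        System.out.println(Arrays.toString(implied(prob, alias)));
    }
}
```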


There is a straightforward algorithm for picking an item at random, where items have individual weights:

  1. calculate the sum of all the weights

  2. pick a random number that is 0 or greater and is less than the sum of the weights

  3. go through the items one at a time, subtracting their weight from your random number until you get the item where the random number is less than that item's weight
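The three steps above translate almost directly into a loop over an array of weights. A minimal sketch, assuming non-negative weights and an injected Random; SubtractionPick is an illustrative name:

```java
import java.util.Random;

public class SubtractionPick {
    // Returns index i with probability weights[i] / sum(weights).
    static int pick(double[] weights, Random random) {
        // Step 1: sum of all the weights.
        double total = 0;
        for (double w : weights) total += w;
        // Step 2: random number in [0, total).
        double r = random.nextDouble() * total;
        // Step 3: subtract each weight until r is below the current item's weight.
        for (int i = 0; i < weights.length; i++) {
            if (r < weights[i]) return i;
            r -= weights[i];
        }
        // Rounding in the subtractions can leave r marginally too large;
        // fall back to the last index with a positive weight.
        for (int i = weights.length - 1; i >= 0; i--) {
            if (weights[i] > 0) return i;
        }
        throw new IllegalArgumentException("total weight must be positive");
    }

    public static void main(String[] args) {
        double[] weights = {10, 5, 6, 1};
        Random random = new Random(11);
        int[] counts = new int[weights.length];
        for (int i = 0; i < 22_000; i++) counts[pick(weights, random)]++;
        System.out.println(java.util.Arrays.toString(counts));
    }
}
```

A zero-weight item is never returned, because r < 0 is never true.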

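import java.util.NavigableMap;
import java.util.TreeMap;
import java.util.concurrent.ThreadLocalRandom;

public class RandomCollection<E> {
  // Keys are running totals; the smallest key >= the random value identifies the item.
  private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
  private double total = 0;

  public void add(double weight, E result) {
    // Skip non-positive weights and items already present (containsValue is a linear scan).
    if (weight <= 0 || map.containsValue(result))
      return;
    total += weight;
    map.put(total, result);
  }

  public E next() {
    double value = ThreadLocalRandom.current().nextDouble() * total;
    return map.ceilingEntry(value).getValue();
  }
}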

If you need to remove elements after choosing them, you can use another solution. Add all the elements into a LinkedList, each element as many times as its weight, then use Collections.shuffle(), which, according to the JavaDoc,

Randomly permutes the specified list using a default source of randomness. All permutations occur with approximately equal likelihood.

Finally, get and remove elements using pop() or removeFirst():

Map<String, Integer> map = new HashMap<>();
map.put("Five", 5);
map.put("Four", 4);
map.put("Three", 3);
map.put("Two", 2);
map.put("One", 1);

LinkedList<String> list = new LinkedList<>();

for (Map.Entry<String, Integer> entry : map.entrySet()) {
    for (int i = 0; i < entry.getValue(); i++) {
        list.add(entry.getKey());
    }
}

Collections.shuffle(list);

int size = list.size();
for (int i = 0; i < size; i++) {
    System.out.println(list.pop());
}
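Because the bag holds exactly `weight` copies of each key, draining it yields each key exactly that many times; only the order is random. A sketch of the same idea, assuming an ArrayList filled via Collections.nCopies, shuffled with a seeded Collections.shuffle(List, Random) and then wrapped in an ArrayDeque for pop(); ShuffleBag and the seed are illustrative:

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

public class ShuffleBag {
    // Builds a shuffled bag containing `weight` copies of each key.
    static Deque<String> fill(Map<String, Integer> weights, Random random) {
        List<String> bag = new ArrayList<>();
        for (Map.Entry<String, Integer> e : weights.entrySet()) {
            bag.addAll(Collections.nCopies(e.getValue(), e.getKey()));
        }
        Collections.shuffle(bag, random);
        return new ArrayDeque<>(bag);
    }

    public static void main(String[] args) {
        Map<String, Integer> weights = new LinkedHashMap<>();
        weights.put("Five", 5);
        weights.put("Two", 2);
        weights.put("One", 1);
        Deque<String> bag = fill(weights, new Random(7));
        while (!bag.isEmpty()) {
            System.out.println(bag.pop()); // each draw removes one ticket from the bag
        }
    }
}
```

Shuffling an ArrayList avoids the extra copy that Collections.shuffle makes for lists without fast random access, such as LinkedList.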

A simple (even naive?) but, I believe, straightforward method:

/**
 * Draws an integer in the range [min, max), i.e. excluding the upper limit.
 * <p>
 * Similar to Python's random.randrange(min, max); Python's randint, by contrast, includes the upper limit.
 *
 * @param min the smallest value that can be drawn.
 * @param max the exclusive upper limit.
 * @return The value drawn.
 */
public static int randomInt(int min, int max) {
    return (int) (min + Math.random() * (max - min));
}

/**
 * Tests whether all inner vectors of a given matrix
 * have the expected length.
 * @param matrix the matrix whose vectors' lengths will be measured.
 * @param expectedLength the length each vector should have.
 * @return false if at least one vector has a different length.
 */
public static boolean haveAllVectorsEqualLength(int[][] matrix, int expectedLength) {
    for (int[] vector : matrix) {if (vector.length != expectedLength) {return false;}}
    return true;
}

/**
 * Draws one of the given integer values,
 * weighted by how many tickets each value has.
 *
 * @param ticketBlock matrix of {value, weight} pairs for the drawing. All its
 * vectors should have length two. The weights, instead of percentages, should be
 * integers reflecting how rare each value should be, the rarest value
 * receiving the smallest weight.
 * @return The value drawn.
 */
public static int weightedRandomInt(int[][] ticketBlock) throws RuntimeException {
    boolean theVectorsHaventAllLengthTwo = !(haveAllVectorsEqualLength(ticketBlock, 2));
    if (theVectorsHaventAllLengthTwo)
        {throw new RuntimeException("The given matrix has, at least, one vector with length lower or higher than two.");}
    // Need to test for duplicates or null values in ticketBlock!

    // Raffle urn building: one slot per ticket.
    int raffleUrnSize = 0;
    for (int[] ticket : ticketBlock) {
        if (ticket[1] < 0) {throw new RuntimeException("Weights must not be negative.");}
        raffleUrnSize += ticket[1];
    }
    if (raffleUrnSize == 0) {throw new RuntimeException("The total weight must be greater than zero.");}
    int[] raffleUrn = new int[raffleUrnSize];

    // Raffle urn filling: ticket[1] copies of each value ticket[0] (a weight of 0 adds none).
    int urnIndex = 0;
    for (int[] ticket : ticketBlock) {
        for (int repetition = 0; repetition < ticket[1]; repetition++) {
            raffleUrn[urnIndex++] = ticket[0];
        }
    }

    return raffleUrn[randomInt(0, raffleUrn.length)];
}
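For example, a ticketBlock of {{7, 3}, {9, 1}} produces the urn [7, 7, 7, 9], so 7 is drawn three times as often as 9. A minimal standalone sketch of just the urn-filling step; buildUrn is a hypothetical helper, not part of the answer, and it uses Arrays.fill over each value's slice of the urn:

```java
import java.util.Arrays;
import java.util.concurrent.ThreadLocalRandom;

public class RaffleUrnDemo {
    // Expands {value, weight} pairs into an array holding `weight` copies of each value.
    static int[] buildUrn(int[][] ticketBlock) {
        int size = 0;
        for (int[] ticket : ticketBlock) size += ticket[1];
        int[] urn = new int[size];
        int index = 0;
        for (int[] ticket : ticketBlock) {
            Arrays.fill(urn, index, index + ticket[1], ticket[0]);
            index += ticket[1];
        }
        return urn;
    }

    public static void main(String[] args) {
        int[][] ticketBlock = {{7, 3}, {9, 1}};
        int[] urn = buildUrn(ticketBlock);
        System.out.println(Arrays.toString(urn)); // [7, 7, 7, 9]
        // A uniform index into the urn gives the weighted choice.
        System.out.println(urn[ThreadLocalRandom.current().nextInt(urn.length)]);
    }
}
```

The urn costs memory proportional to the sum of the weights, so this approach suits small integer weights.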

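Statement: the technical posts on this site are published under the CC BY-SA 4.0 license. If you need to repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.

Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM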