
Generate an array of random integers with non-uniform distribution

I want to write Java code to produce an array of random integers in the range [1,4]. The array's length is N, which is provided at run time. The problem is that the values in [1,4] are not uniformly distributed:

[Image: graph of the target probability distribution over the values 1 to 4]

It means that if I create arrays with N=100, the number '1' will appear on average 40 times in an array, the number '2' 10 times, and so on.

For now I am using this code to generate uniformly distributed random numbers in the range [1,4]:

public static void main(String[] args)
    {
        Scanner input = new Scanner(System.in);   // java.util.Scanner
        int N;
        System.out.println();
        System.out.print("Enter an integer number: ");
        N = input.nextInt();
        int[] a = new int[N];
        Random generator = new Random();          // java.util.Random
        for(int i = 0; i < a.length; i++)
        {
            a[i] = generator.nextInt(4)+1;        // uniform over 1..4
        }
    }

How do I implement it with the non-uniform distribution shown in the graph above?

Here's a way to do it, starting from your code:

public static void main(String[] args){
    Scanner input = new Scanner(System.in);   // java.util.Scanner
    int N;
    System.out.println();
    System.out.print("Enter an integer number: ");
    N = input.nextInt();
    int[] a = new int[N];
    Random generator = new Random();
    for (int i = 0; i < a.length; i++) {
        double n = generator.nextDouble();    // uniform in [0,1)
        if (n < 0.4) {            // P(1) = 0.4
            a[i] = 1;
        } else if (n < 0.7) {     // P(3) = 0.3
            a[i] = 3;
        } else if (n < 0.9) {     // P(4) = 0.2
            a[i] = 4;
        } else {                  // P(2) = 0.1
            a[i] = 2;
        }
    }
}

UPDATE: at @pjs' suggestion, the numbers are now selected in order of descending probability, so you tend to exit the if block earlier.

Another easy solution is to use nextDouble(), which generates a random double in [0,1). If the value is < .4 choose 1, else if it is < (.4 + .1) choose 2, etc., with the last branch always choosing the last choice. This is easily generalized using a for loop.

For a more generic approach, you can populate a NavigableMap with the cumulative distribution probabilities:

double[] probs = {0.4, 0.1, 0.2, 0.3};
NavigableMap<Double, Integer> distribution = new TreeMap<Double, Integer>();
for(double p : probs) {
    distribution.put(distribution.isEmpty() ? p : distribution.lastKey() + p, distribution.size() + 1);
}

and later query the map with a uniformly distributed random key in the range [0, 1):

Random rnd = new Random();
for(int i=0; i<20; i++) {
    System.out.println(distribution.ceilingEntry(rnd.nextDouble()).getValue());
}

The first snippet populates the map with the following key/value pairs:

0.4 -> 1
0.5 -> 2
0.7 -> 3
1.0 -> 4

To query the map, you first generate a uniformly distributed double in the range 0 to 1. Querying the map using the ceilingEntry method and passing the random number will return the "mapping associated with the least key greater than or equal to the given key", so e.g. passing a value in the range (0.4, 0.5] will return the entry with the mapping 0.5 -> 2. Using getValue() on the returned map entry will hence return 2.
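One edge case may be worth guarding against: the keys are built by floating-point addition, so the last key can land slightly below 1.0, and ceilingEntry returns null for any key above it. The following is a minimal sketch of such a guard, not part of the answer above; the class name CumulativeMapSampler and its methods build and lookup are hypothetical names chosen for illustration.

```java
import java.util.Map;
import java.util.NavigableMap;
import java.util.Random;
import java.util.TreeMap;

// Hypothetical helper class; the names here are illustrative assumptions.
class CumulativeMapSampler {

    // Maps each cumulative probability to the values 1..probs.length.
    // Assumes every probability is > 0 (a zero would overwrite the previous key).
    static NavigableMap<Double, Integer> build(double[] probs) {
        NavigableMap<Double, Integer> distribution = new TreeMap<>();
        double cumulative = 0.0;
        for (int i = 0; i < probs.length; i++) {
            cumulative += probs[i];
            distribution.put(cumulative, i + 1);
        }
        return distribution;
    }

    // Returns the value whose cumulative key is the least key >= the given key.
    // Falls back to the last entry if the key lies above every key in the map.
    static int lookup(NavigableMap<Double, Integer> distribution, double key) {
        Map.Entry<Double, Integer> entry = distribution.ceilingEntry(key);
        if (entry == null) {
            entry = distribution.lastEntry();
        }
        return entry.getValue();
    }

    public static void main(String[] args) {
        NavigableMap<Double, Integer> distribution =
                build(new double[] {0.4, 0.1, 0.2, 0.3});
        Random rnd = new Random();
        for (int i = 0; i < 20; i++) {
            System.out.println(lookup(distribution, rnd.nextDouble()));
        }
    }
}
```

The fallback only changes behavior for keys above the largest cumulative key; every other lookup is the same ceilingEntry call as before.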

Let a1, a2, a3 and a4 be doubles that specify the relative probabilities, and let s = a1+a2+a3+a4. That means the probability for 1 is a1/s, the probability for 2 is a2/s, and so on.

Then create a random double d using generator.nextDouble().

If 0 <= d < a1/s then the integer should be 1,

if a1/s <= d < (a1+a2)/s then the integer should be 2,

if (a1+a2)/s <= d < (a1+a2+a3)/s then the integer should be 3,

if (a1+a2+a3)/s <= d < 1 then the integer should be 4.
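The interval tests above can be sketched as a single loop over the weights. This is only an illustration under assumptions: the class name WeightedPick, the method pick, and the example weights {4, 1, 3, 2} are hypothetical and not taken from the answer. Instead of dividing each running sum by s, the sketch compares d*s against the running sum, which is the same test.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical helper class; names and example weights are assumptions.
class WeightedPick {

    // Maps d in [0,1) to a value in 1..weights.length, where value k is
    // chosen with probability weights[k-1] / (sum of all weights).
    static int pick(double[] weights, double d) {
        double s = 0.0;
        for (double w : weights) {
            s += w;
        }
        double target = d * s;   // d < running/s  is the same as  d*s < running
        double running = 0.0;
        for (int i = 0; i < weights.length - 1; i++) {
            running += weights[i];
            if (target < running) {
                return i + 1;
            }
        }
        return weights.length;   // last interval: everything that is left
    }

    public static void main(String[] args) {
        double[] weights = {4, 1, 3, 2};   // relative weights, need not sum to 1
        Random generator = new Random();
        int[] a = new int[10];
        for (int i = 0; i < a.length; i++) {
            a[i] = pick(weights, generator.nextDouble());
        }
        System.out.println(Arrays.toString(a));
    }
}
```

Because the weights are normalized by their sum inside pick, they can be given as counts or percentages as well as probabilities.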

A slightly more extensible version of Miquel's (and also what Teresa suggested):

    double[] distro = {.4, .1, .3, .2};   // P(1), P(2), P(3), P(4)
    System.out.println();
    System.out.print("Enter an integer number: ");
    Scanner input = new Scanner(System.in);
    int N = input.nextInt();
    int[] a = new int[N];
    Random generator = new Random();
    outer:
    for (int i = 0; i < a.length; i++) {
        double rand = generator.nextDouble();
        double val = 0;   // running cumulative probability
        for (int j = 1; j < distro.length; j++) {
            val += distro[j - 1];
            if (rand < val) {
                a[i] = j;
                continue outer;
            }
        }
        a[i] = distro.length;   // fell through every threshold: last value
    }
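A quick way to sanity-check a sampler like this is to draw a large array and compare the observed frequencies with the target probabilities. The following is a rough sketch under assumptions, not part of the answer: the class name FrequencyCheck, its methods, the fixed seed 42 and the sample size are all hypothetical choices made for illustration.

```java
import java.util.Random;

// Hypothetical helper class; names, seed and sample size are assumptions.
class FrequencyCheck {

    // Fills an array of length n with values 1..distro.length using the
    // cumulative-threshold loop described above.
    static int[] generate(double[] distro, int n, Random generator) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            double rand = generator.nextDouble();
            double val = 0;
            a[i] = distro.length;   // default: the last value
            for (int j = 1; j < distro.length; j++) {
                val += distro[j - 1];
                if (rand < val) {
                    a[i] = j;
                    break;
                }
            }
        }
        return a;
    }

    // Returns the fraction of entries equal to each value 1..k.
    static double[] frequencies(int[] a, int k) {
        double[] freq = new double[k];
        for (int value : a) {
            freq[value - 1] += 1.0 / a.length;
        }
        return freq;
    }

    public static void main(String[] args) {
        double[] distro = {.4, .1, .3, .2};
        int[] a = generate(distro, 100_000, new Random(42));
        double[] freq = frequencies(a, distro.length);
        for (int k = 0; k < distro.length; k++) {
            System.out.printf("%d: observed %.4f, target %.2f%n",
                    k + 1, freq[k], distro[k]);
        }
    }
}
```

With a sample this large the observed fractions should sit close to the targets; a large gap usually points to a wrong threshold or an off-by-one in the value mapping.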

For the specific problem you gave above, the solutions provided by others work very well and the alias method would be overkill. However, you said in a comment that you were actually going to use this in a distribution with a much larger range. In that case, the overhead of setting up an alias table may be worthwhile to get the O(1) behavior for actually generating values.

Here's source in Java. It's easy to revert it back to using Java's stock Random if you don't want to grab Mersenne Twister:

/*
 * Created on Mar 12, 2007
 *    Feb 13, 2011: Updated to use Mersenne Twister - pjs
 */
package edu.nps.or.simutils;

import java.lang.IllegalArgumentException;
import java.text.DecimalFormat;
import java.util.Comparator;
import java.util.Stack;
import java.util.PriorityQueue;
import java.util.Random;

import net.goui.util.MTRandom;

public class AliasTable<V> {
   private static Random r = new MTRandom();
   private static DecimalFormat df2 = new DecimalFormat(" 0.00;-0.00");

   private V[] primary;
   private V[] alias;
   private double[] primaryP;
   private double[] primaryPgivenCol;

   private static boolean notCloseEnough(double target, double value) {
      return Math.abs(target - value) > 1E-10;
   }

   /**
    * Constructs the AliasTable given the set of values
    * and corresponding probabilities.
    * @param value
    *   An array of the set of outcome values for the distribution. 
    * @param pOfValue
    *   An array of corresponding probabilities for each outcome.
    * @throws IllegalArgumentException
    *   The values and probability arrays must be of the same length,
    *   the probabilities must all be positive, and they must sum to one.
    */
   public AliasTable(V[] value, double[] pOfValue) {
      super();      
      if (value.length != pOfValue.length) {
         throw new IllegalArgumentException(
               "Args to AliasTable must be vectors of the same length.");
      }
      double total = 0.0;
      for (double d : pOfValue) {
         if (d < 0) {
            throw new
               IllegalArgumentException("p_values must all be positive.");
         }
         total += d;
      }
      if (notCloseEnough(1.0, total)) {
         throw new IllegalArgumentException("p_values must sum to 1.0");
      }

      // Done with the safety checks, now let's do the work...

      // Cloning the values prevents people from changing outcomes
      // after the fact.
      primary = value.clone();
      alias = value.clone();
      primaryP = pOfValue.clone();
      primaryPgivenCol = new double[primary.length];
      for (int i = 0; i < primaryPgivenCol.length; ++i) {
         primaryPgivenCol[i] = 1.0;
      }
      double equiProb = 1.0 / primary.length;

      /*
       * Internal classes are UGLY!!!!
       * We're what you call experts.  Don't try this at home!
       */
      class pComparator implements Comparator<Integer> {
         public int compare(Integer i1, Integer i2) {
            return primaryP[i1] < primaryP[i2] ? -1 : 1;
         }
      }

      PriorityQueue<Integer> deficitSet =
         new PriorityQueue<Integer>(primary.length, new pComparator());
      Stack<Integer> surplusSet = new Stack<Integer>();

      // initial allocation of values to deficit/surplus sets
      for (int i = 0; i < primary.length; ++i) {
         if (notCloseEnough(equiProb, primaryP[i])) {
            if (primaryP[i] < equiProb) {
               deficitSet.add(i);
            } else {
               surplusSet.add(i);
            }
         }
      }

      /*
       * Pull the largest deficit element from what remains.  Grab as
       * much probability as you need from a surplus element.  Re-allocate
       * the surplus element based on the amount of probability taken from
       * it to the deficit, surplus, or completed set.
       * 
       * Lather, rinse, repeat.
       */
      while (!deficitSet.isEmpty()) {
         int deficitColumn = deficitSet.poll();
         int surplusColumn = surplusSet.pop();
         primaryPgivenCol[deficitColumn] = primaryP[deficitColumn] / equiProb;
         alias[deficitColumn] = primary[surplusColumn];
         primaryP[surplusColumn] -= equiProb - primaryP[deficitColumn];
         if (notCloseEnough(equiProb, primaryP[surplusColumn])) {
            if (primaryP[surplusColumn] < equiProb) {
               deficitSet.add(surplusColumn);
            } else {
               surplusSet.add(surplusColumn);
            }
         }
      }
   }

   /**
    * Generate a value from the input distribution.  The alias table
    * does this in O(1) time, regardless of the number of elements in
    * the distribution.
    * @return
    *   A value from the specified distribution.
    */
   public V generate() {
      int column = (int) (primary.length * r.nextDouble());
      return r.nextDouble() <= primaryPgivenCol[column] ?
                  primary[column] : alias[column];
   }

   public void printAliasTable() {
      System.err.println("Primary\t\tprimaryPgivenCol\tAlias");
      for(int i = 0; i < primary.length; ++i) {
         System.err.println(primary[i] + "\t\t\t"
            + df2.format(primaryPgivenCol[i]) + "\t\t" + alias[i]);
      }
      System.err.println();
   }
}
