
How to optimize rejection sampling

I have a std::map, mymap, that I am trying to sample according to the value stored for each key. I have set up an algorithm based on rejection sampling that seems to work, but it is extremely slow (it gets called thousands of times in my program).

So I am wondering if this would be the best approach or if there is something quicker/more efficient that I could be doing instead.

Here is what I have so far below:

std::map<int, float> mymap; // The map that I am sampling

// These three values are precomputed
int minKey;     // Min key in the map
int maxKey;     // Max key in the map
float maxValue; // Max value in the map

float x1, x2; // Two random variables
int key;
float value;
do
{
    x1 = (float)rand() / (float)RAND_MAX;            // uniform in [0, 1]
    x2 = maxValue * (float)rand() / (float)RAND_MAX; // uniform in [0, maxValue]
    key = minKey * (1.0 - x1) + maxKey * x1; // Linearly interpolate random value to get key
    value = mymap[key]; // Get value (operator[] inserts 0 for a missing key)
} while (x2 > value);

return std::pair<int, float>(key, value);

So what I am doing above is selecting a key uniformly at random, then drawing a second random variable and comparing it against that key's value. If the random variable is larger, I repeat the process. This way, keys with higher values get sampled more often than keys with lower values. However, the do-while loop can run many times before it finds an acceptable key-value pair, and this is causing quite a bottleneck in my application.

EDIT

Also, do I need to adjust my samples in any way, since they are biased here? I know that in Monte Carlo integration you have to divide the value of the sample by the PDF of that sample, but I'm not sure whether that applies here. If it does, how would I find the PDF?

If you want to bias your sample linearly in proportion to the values, it's easy to do.

Start by calculating the sum of all the values.

Now generate a single random floating-point value between 0 and the sum.

Iterate through the map, summing the values as you go. When the sum is greater than the random value calculated earlier, you've found your sample.

If you'll be doing this repeatedly on an unchanging map, you can create a vector of sums and do a binary search for the random value.
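The steps above can be sketched roughly as follows, assuming the map does not change between draws and holds at least one positive value. The class name CumulativeSampler and its members are hypothetical and chosen only for illustration; std::upper_bound performs the binary search for the first running sum greater than the random value.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <random>
#include <vector>

// Hypothetical helper: precomputes the running sums of the map's values so
// that each draw costs one random number and one binary search.
class CumulativeSampler
{
    std::vector<int> keys;
    std::vector<double> sums; // sums[i] = value[0] + ... + value[i]

public:
    explicit CumulativeSampler(const std::map<int, float>& m)
    {
        assert(!m.empty()); // assumption: there is something to sample
        double total = 0.0;
        for (const auto& kv : m)
        {
            total += kv.second;
            keys.push_back(kv.first);
            sums.push_back(total);
        }
    }

    double total() const { return sums.back(); }

    // Returns the key of the first running sum greater than r,
    // where r is expected to lie in [0, total()).
    int pick(double r) const
    {
        auto it = std::upper_bound(sums.begin(), sums.end(), r);
        if (it == sums.end())
            --it; // guard against r landing exactly on the total
        return keys[it - sums.begin()];
    }

    // Draws one key with probability proportional to its value.
    template <typename Gen>
    int operator()(Gen& gen) const
    {
        std::uniform_real_distribution<double> u(0.0, total());
        return pick(u(gen));
    }
};
```

A caller would construct the sampler once from mymap, keep a single generator such as std::mt19937 around, and call the sampler with that generator for every draw. Each draw is then logarithmic in the number of keys, with no rejection loop.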

Rejection sampling is primarily useful for continuous distributions. What you need is to sample a discrete distribution. Fortunately, this is part of the standard library as of C++11. So, adapted from the sample for std::discrete_distribution:

#include <iostream>
#include <map>
#include <random>
#include <vector>

template <typename T>
class sampler
{
    std::vector<T> keys;
    std::discrete_distribution<std::size_t> distr; // yields an index into keys, so its type must not depend on T

public:
    sampler(const std::vector<T>& keys, const std::vector<float>& prob) :
        keys(keys), distr(prob.begin(), prob.end()) { }

    T operator()()
    {
        static std::random_device rd;
        static std::mt19937 gen(rd());
        return keys[distr(gen)];
    }
};

int main()
{
    using T = int;
    sampler<T> samp({19, 54, 192, 732}, {.1, .2, .4, .3});
    std::map<T, size_t> hist;

    for (size_t n = 0; n < 10000; ++n)
        ++hist[samp()];

    for (auto i: hist)
    {
        std::cout << i.first << " generated " <<
        i.second << " times" << std::endl;
    }
}

Output:

19 generated 1010 times
54 generated 2028 times
192 generated 3957 times
732 generated 3005 times

The vectors keys and prob hold the keys and the values (probabilities) of your map separately. This is because std::discrete_distribution takes only the probabilities into account.
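As a rough illustration of how those two vectors might be filled from the map in the question, here is a minimal sketch. The function names split_map and sample_key are hypothetical, and the key and value types are assumed to be int and float as in the question. The weights do not have to sum to 1, because std::discrete_distribution normalizes them.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <random>
#include <utility>
#include <vector>

// Hypothetical helper: splits a map into parallel vectors of keys and weights.
std::pair<std::vector<int>, std::vector<float>>
split_map(const std::map<int, float>& m)
{
    std::vector<int> keys;
    std::vector<float> weights;
    keys.reserve(m.size());
    weights.reserve(m.size());
    for (const auto& kv : m)
    {
        keys.push_back(kv.first);
        weights.push_back(kv.second);
    }
    return std::make_pair(std::move(keys), std::move(weights));
}

// Hypothetical helper: draws an index from the distribution and maps it
// back to the corresponding key.
template <typename Gen>
int sample_key(const std::vector<int>& keys,
               std::discrete_distribution<std::size_t>& distr, Gen& gen)
{
    assert(!keys.empty());
    return keys[distr(gen)];
}
```

The distribution itself would be built once from the weights with the iterator-range constructor, for example from weights.begin() and weights.end(), and then reused for every draw.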

Note that operator() cannot be const because std::discrete_distribution changes state (naturally) at every sample.

Also note that even if you implement sampling yourself using the cumulative distribution and binary search (whereby sampling is logarithmic-time in the size of your domain), there are more efficient (constant-time) sampling methods, such as the alias method. I am not sure which method std::discrete_distribution uses, however.
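For reference, a minimal sketch of the alias method in the variant usually attributed to Vose is given below. It is an illustration under stated assumptions, not a description of what std::discrete_distribution does internally. The class name AliasTable and its members are hypothetical, and the weights are assumed to be non-negative with a positive sum.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Sketch of the alias method: O(n) setup, O(1) work per sample.
class AliasTable
{
    std::vector<double> prob;       // chance of keeping column i
    std::vector<std::size_t> alias; // index returned otherwise

public:
    explicit AliasTable(const std::vector<double>& weights)
        : prob(weights.size()), alias(weights.size())
    {
        assert(!weights.empty());
        const std::size_t n = weights.size();
        double total = 0.0;
        for (double w : weights) total += w;
        assert(total > 0.0);

        // Scale the weights so that their average is 1, then split the
        // indices into under-full (< 1) and over-full (>= 1) columns.
        std::vector<double> scaled(n);
        std::vector<std::size_t> small, large;
        for (std::size_t i = 0; i < n; ++i)
        {
            scaled[i] = weights[i] * n / total;
            (scaled[i] < 1.0 ? small : large).push_back(i);
        }

        // Top up each under-full column with mass from an over-full one.
        while (!small.empty() && !large.empty())
        {
            const std::size_t s = small.back(); small.pop_back();
            const std::size_t l = large.back(); large.pop_back();
            prob[s] = scaled[s];
            alias[s] = l;
            scaled[l] -= 1.0 - scaled[s];
            (scaled[l] < 1.0 ? small : large).push_back(l);
        }
        // Whatever is left is a full column, up to rounding error.
        for (std::size_t i : large) { prob[i] = 1.0; alias[i] = i; }
        for (std::size_t i : small) { prob[i] = 1.0; alias[i] = i; }
    }

    // Deterministic core: column is a uniformly chosen index, u is in [0, 1).
    std::size_t pick(std::size_t column, double u) const
    {
        return u < prob[column] ? column : alias[column];
    }

    // Draws one index with probability proportional to its weight.
    template <typename Gen>
    std::size_t operator()(Gen& gen) const
    {
        std::uniform_int_distribution<std::size_t> col(0, prob.size() - 1);
        std::uniform_real_distribution<double> unit(0.0, 1.0);
        return pick(col(gen), unit(gen));
    }
};
```

The returned value is an index into the weight vector, so it would be mapped back to a key through a parallel vector of keys, in the same way as in the sampler class above. Setup is linear in the number of weights, so this pays off mainly when many samples are drawn from the same unchanging map.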

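One possibility is to use a second map (or set) holding the keys that are not yet known to be bad. Put all the keys in it to begin with; once a key is rejected, because it is bigger than the initial random variable, remove it from that map. You then search for keys in the not-yet-known-bad set rather than in the whole map.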
