簡體   English   中英

以張量流中的間隔為條件對高斯進行采樣

[英]Sampling a Gaussian conditioned on an interval in tensorflow

假設X是高斯,N(0,1),並且我們想要對於常數x1,x2給出x1 <= X <= x2的樣本X. 如何在tensorflow中執行此操作?

這使用了special_math函數ndtr和ndtri,高斯分布和逆分布函數。 由於目前無法通過搜索API找到這些函數,因此其價值包括此處。

import numpy as np
import tensorflow as tf
from tensorflow.python.ops.distributions import special_math as dsm

import matplotlib.pylab as pl

#assuming x1 < x2
def tf_conditioned_normal(x1,x2, dtype = tf.float32):

    Fx1 = dsm.ndtr(x1)
    Fx2 = dsm.ndtr(x2)
    gamma = tf.random_uniform([1], dtype = dtype)            
    return dsm.ndtri(Fx1  + gamma*(Fx2 - Fx1))


if __name__ == '__main__':

    graph = tf.Graph()
    with graph.as_default():

        t_x1ph = tf.placeholder(tf.float32,[])
        t_x2ph = tf.placeholder(tf.float32,[])

        t_cn = tf_conditioned_normal(t_x1ph,t_x2ph, dtype = tf.float32)

        t_rn = tf.random_normal([1])


    sess  = tf.Session(graph = graph)

    print 'Conditioned...'

    x1 = -5.
    x2 = -1.

    N = 5000
    res = np.zeros(N)

    for i in xrange(N):        
        res[i] = sess.run(t_cn,
            feed_dict = {
                t_x1ph : x1 ,
                t_x2ph :  x2 ,
                }
            )    


    print 'Regular...'

    Nn = 50000
    nres = np.zeros(Nn)

    for i in xrange(Nn):        
        nres[i] = sess.run(t_rn)    

    nres = nres[ (nres>=x1) & (nres <= x2) ]

    pl.figure()
    tmp = pl.hist(res, np.linspace(x1,x2,200), normed = True)
    tmp = pl.hist(nres, np.linspace(x1,x2,200), normed = True, alpha = 0.7)
    pl.show()

你可以簡單地做到 -

import tensorflow as tf

x1 = tf.constant(0.)
x2 = tf.constant(1.)
N = tf.constant(10)

# Define a batch of one scalar valued Normals.
# The mean is 0. and standard deviation 1.
dist = tf.distributions.Normal(loc=[0.], scale=[1.])

# Get N samples, returning a N x 1 tensor.
sample = dist.sample([N])

# Put the filters and get boolean mask
filters = tf.logical_and(sample>x1, sample<x2)

# Put the filter in place
final_sample = tf.boolean_mask(sample, filters)

# Check output
sess = tf.InteractiveSession()
print(final_sample.eval())

產量

[0.11488124 0.38626793 0.3822059  0.3888869 ]

唯一的挑戰是你必須保持采樣和過濾,直到你達到你的N 我把它作為一個微不足道的補充。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM