
How to specify size for a Bernoulli distribution with pymc3?

While working through Bayesian Methods for Hackers, which is written for pymc, I came across this code:

first_coin_flips = pm.Bernoulli("first_flips", 0.5, size=N)

I've tried to translate this to pymc3 with the following, but it just returns a numpy array, rather than a tensor (?):

first_coin_flips = pm.Bernoulli("first_flips", 0.5).random(size=50)

The reason the size matters is that it's used later on in a deterministic variable. Here's the entirety of the code that I have so far:

import pymc3 as pm
import matplotlib.pyplot as plt
import numpy as np
import mpld3
import theano.tensor as tt

model = pm.Model()
with model:
    N = 100
    p = pm.Uniform("cheating_freq", 0, 1)
    true_answers = pm.Bernoulli("truths", p)
    print(true_answers)
    first_coin_flips = pm.Bernoulli("first_flips", 0.5)
    second_coin_flips = pm.Bernoulli("second_flips", 0.5)
    #  print(first_coin_flips.value)

    # Create model variables
    def calc_p(true_answers, first_coin_flips, second_coin_flips):
        observed = first_coin_flips * true_answers + (1-first_coin_flips) * second_coin_flips
        # NOTE: Where I think the size param matters, since we're dividing by it
        return observed.sum() / float(N)

    calced_p = pm.Deterministic("observed", calc_p(true_answers, first_coin_flips, second_coin_flips))
    step = pm.Metropolis(model.free_RVs)
    trace = pm.sample(1000, tune=500, step=step)
    pm.traceplot(trace)

    html = mpld3.fig_to_html(plt.gcf())
    with open("output.html", 'w') as f:
        f.write(html)
        f.close()

And the output:

[Image: output]

The coin flips and the uniform cheating_freq output look correct, but the observed plot doesn't look meaningful to me, and I think it's because I'm not translating that size param correctly.

The pymc3 way to specify the size of a Bernoulli distribution is to use the shape parameter, e.g.:

first_coin_flips = pm.Bernoulli("first_flips", 0.5, shape=N)
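
To see why the shape matters for the deterministic variable, here is a minimal NumPy-only sketch (it does not use pymc3) that applies the same arithmetic as calc_p in the question to scalar draws and to length-N draws. The value p = 0.3, the fixed seed, and the helper name observed_proportion are assumptions made purely for illustration. Presumably the same shape=N would also be needed on "truths" and "second_flips" so that all three variables have length N.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed, chosen only for reproducibility
N = 100   # number of respondents, as in the question
p = 0.3   # hypothetical cheating frequency, for illustration only


def observed_proportion(true_answers, first_flips, second_flips, n):
    """Same arithmetic as calc_p in the question (hypothetical helper name)."""
    observed = first_flips * true_answers + (1 - first_flips) * second_flips
    return np.sum(observed) / float(n)


# Scalar draws: analogous to the model without shape=N.
# The sum is a single 0 or 1, so the result can only be 0 or 1/N.
scalar = observed_proportion(
    rng.binomial(1, p),
    rng.binomial(1, 0.5),
    rng.binomial(1, 0.5),
    N,
)

# Length-N draws: analogous to passing shape=N to each Bernoulli.
# The sum counts "yes" responses, so the result is a proportion in [0, 1].
vector = observed_proportion(
    rng.binomial(1, p, size=N),
    rng.binomial(1, 0.5, size=N),
    rng.binomial(1, 0.5, size=N),
    N,
)

print("scalar draws:", scalar)
print("vector draws:", vector)
```

With scalar draws the "proportion" is stuck at 0 or 1/N, which is consistent with the observed trace in the question not looking like anything. With length-N draws it should land near 0.5 * p + 0.25, the expected "yes" rate under this coin-flip scheme.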
