[英]Convert numpy function to theano
我正在使用PyMC3
來計算一些我不會在這里討論的內容,但是如果您感興趣的話,可以從此鏈接中獲得想法。
“ 2-lambdas”情況基本上是一個開關函數,需要將其編譯為Theano
函數,以避免dtype
錯誤,它看起來像這樣:
import theano
from theano.tensor import lscalar, dscalar, lvector, dvector, argsort
@theano.compile.ops.as_op(itypes=[lscalar, dscalar, dscalar], otypes=[dvector])
def lambda_2_distributions(tau, lambda_1, lambda_2):
"""
Return values of `lambda_` for each observation based on the
transition value `tau`.
"""
out = zeros(num_observations)
out[: tau] = lambda_1 # lambda before tau is lambda1
out[tau:] = lambda_2 # lambda after (and including) tau is lambda2
return out
我試圖將其概括為適用於taus.shape[0] = lambdas.shape[0] - 1
'n-lambdas',但是我只能想出這種極其緩慢的numpy
實現。
@theano.compile.ops.as_op(itypes=[lvector, dvector], otypes=[dvector])
def lambda_n_distributions(taus, lambdas):
out = zeros(num_observations)
np_tau_indices = argsort(taus).eval()
num_taus = taus.shape[0]
for t in range(num_taus):
if t == 0:
out[: taus[np_tau_indices[t]]] = lambdas[t]
elif t == num_taus - 1:
out[taus[np_tau_indices[t]]:] = lambdas[t + 1]
else:
out[taus[np_tau_indices[t]]: taus[np_tau_indices[t + 1]]] = lambdas[t]
return out
關於如何使用純Theano
加快速度的任何想法(避免調用.eval()
)? 自從我使用它已經有幾年了,所以不知道正確的方法。
不建議使用開關功能,因為它會破壞參數空間的幾何形狀,並使使用NUTS之類的現代采樣器采樣變得困難。
相反,您可以嘗試使用連續放松的開關功能對其進行建模。 這里的主要思想是將第一個轉換點之前的速率建模為基線; 並在每個切換點之后添加來自邏輯函數的預測:
def logistic(L, x0, k=500, t=np.linspace(0., 1., 1000)):
return L/(1+tt.exp(-k*(t_-x0)))
with pm.Model() as m2:
lambda0 = pm.Normal('lambda0', mu, sd=sd)
lambdad = pm.Normal('lambdad', 0, sd=sd, shape=nbreak-1)
trafo = Composed(pm.distributions.transforms.LogOdds(), Ordered())
b = pm.Beta('b', 1., 1., shape=nbreak-1, transform=trafo,
testval=[0.3, 0.5])
theta_ = pm.Deterministic('theta', tt.exp(lambda0 +
logistic(lambdad[0], b[0]) +
logistic(lambdad[1], b[1])))
obs = pm.Poisson('obs', theta_, observed=y)
trace = pm.sample(1000, tune=1000)
我在這里也使用了一些技巧,例如,復合轉換尚未在PyMC3代碼庫上進行。 您可以在此處查看完整的代碼: https : //gist.github.com/junpenglao/f7098c8e0d6eadc61b3e1bc8525dd90d
如果您還有其他問題,請將模型和(模擬的)數據發布到https://discourse.pymc.io 。 我會更定期地檢查和回答PyMC3的話語。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.