I am trying to understand where I am going wrong with rstan. I have figured out a workaround, but it seems like there should be a better option for graphing draws from the posterior than what I have come up with.
I am trying to learn how to use rstan for modeling a Gaussian process, related to another question I have open on CV (shameless plug, but if you have ideas that can help out there I am all ears).
As a first step I worked through the Gaussian process examples in the Stan documentation. So I built a model designed simply to draw random functions from a Gaussian process with a squared exponential covariance function.
library(rstan)
library(rstanarm)
library(bayesplot)
library(ggplot2)
options(mc.cores=parallel::detectCores())
rstan_options(auto_write = TRUE)
x<-seq(0, 30, by=.01)
model<-'
data{
  int<lower=1> N;
  real x[N];
}
transformed data {
  matrix[N, N] L;
  matrix[N, N] K;
  vector[N] mu = rep_vector(0, N);
  for (i in 1:(N - 1)) {
    K[i, i] = 1 + 0.1;
    for (j in (i + 1):N) {
      K[i, j] = exp(-0.5 * square(x[i] - x[j]));
      K[j, i] = K[i, j];
    }
  }
  K[N, N] = 1 + 0.1;
  L = cholesky_decompose(K);
}
parameters {
  vector[N] eta;
}
model {
  eta ~ normal(0, 1);
}
generated quantities {
  vector[N] y;
  y = mu + L*eta;
}
'
I followed the documentation's suggestion of computing the Cholesky decomposition in the transformed data block.
Using stan, I fit the model as follows:
dat<-list(N=length(x),
x=x)
fit <- stan(model_code = model,
data = dat,
iter = 1000,
chains = 1,
pars = c('y', 'eta'),
control = list(adapt_delta=.99,
max_treedepth=10)
)
I can visualize the marginal posterior distributions of individual elements of y using the following code:
posterior<-as.matrix(fit)
mcmc_areas(posterior,
pars=c('y[1]', 'y[2]'),
prob = .90
)
Which produces density plots for y[1] and y[2] (figure not included).
What I really want is to look at individual draws of the whole process (not all 500, just a random subset of them).
I tried multiple alternative strategies and eventually landed on the following:
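post.y<-extract(fit, pars='y')
# pick 10 posterior draws (rows of post.y$y) at random
draws<-sample(1:nrow(post.y$y), size = 10)
DF<-data.frame(Time=x, y=colMeans(post.y$y), Draw=rep('Mu', length(x)))
for(i in seq_along(draws)){
  # index by the sampled draw number, not by the loop counter
  DF.temp<-data.frame(Time=x, y=post.y$y[draws[i],], Draw=rep(paste0('posterior', draws[i]), length(x)))
  DF<-rbind(DF, DF.temp)
}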
g1<-ggplot(aes(x=Time, y=y), data=DF)
g2<-g1+geom_line(aes(x=Time, y=y, group=Draw, color=Draw), data=DF[DF$Draw!='Mu',], alpha=.25, show.legend = F)
g3<-g2+geom_line(aes(x=Time, y=y), data=DF[DF$Draw=='Mu',], lwd=1.5)
g3
This seems like a lot of extra hoops to jump through. I tried alternative approaches using other functions in the rstan family (e.g., ppc_dens_overlay from bayesplot), but they all resulted in errors or did not return what I wanted.
So my question is really about simpler options for visualizing a sample of individual draws of $y$ together with the mean across all draws at each value of x (which should be 0 in this case but may not be in other cases, when the data change over time in a structured way).
I am relatively new to rstan (I have used rbugs and rjags), so I may simply be unaware of some set of functions that makes this process easier.
Thanks in advance for any help.
You could reproduce your second figure with a bit less code using matplot, which conveniently works with matrix data.
post.y <- rstan::extract(fit, 'y')$y
post.y.sub <- post.y[sample(1:nrow(post.y), 10),]
matplot(x, t(post.y.sub), type = 'l', lty = 1, col = adjustcolor(palette(), 0.25))
lines(colMeans(post.y) ~ x, lwd = 2)
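If a band is more useful than a handful of lines, the same matrix layout makes pointwise intervals easy to compute. This is only a minimal sketch: `x.sim` and `post.y.sim` are hypothetical stand-ins, simulated here so the block runs without Stan, and they assume the layout returned by `rstan::extract(fit, 'y')$y` (draws in rows, grid points in columns).

```r
# Stand-in for rstan::extract(fit, 'y')$y: 500 draws (rows) x 50 grid points (columns).
# The simulated values are an assumption for illustration only.
set.seed(1)
x.sim <- seq(0, 30, length.out = 50)
post.y.sim <- matrix(rnorm(500 * length(x.sim)), nrow = 500)

# Pointwise 5%, 50% and 95% quantiles across draws; result is 3 x length(x.sim)
bands <- apply(post.y.sim, 2, quantile, probs = c(0.05, 0.5, 0.95))

# Plot 10 random draws, then overlay the median (solid) and the 90% band (dashed)
idx <- sample(nrow(post.y.sim), 10)
matplot(x.sim, t(post.y.sim[idx, ]), type = 'l', lty = 1,
        col = adjustcolor('grey40', 0.4), xlab = 'x', ylab = 'y')
lines(x.sim, bands[2, ], lwd = 2)
lines(x.sim, bands[1, ], lty = 2)
lines(x.sim, bands[3, ], lty = 2)
```

With a real fit, replacing `post.y.sim` with `post.y` and `x.sim` with `x` should be all that is needed.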
If you prefer ggplot2, the hard part is getting the posterior samples into a data frame. I find the dplyr and tidyr libraries helpful here. It looks like a lot of code, but it's flexible when your models get more complicated.
library(dplyr)
library(tidyr)
df.rep <- post.y %>%
t() %>%
as.data.frame() %>%
mutate(x = x) %>%
gather(rep, post.y, -x)
df.mean <- df.rep %>%
group_by(x) %>%
summarize(mu = mean(post.y))
df.rep.sub <- df.rep %>%
filter(rep %in% sample(unique(rep), 10))
ggplot() +
geom_line(data = df.rep.sub, aes(x, post.y, col = rep), alpha = 0.25, show.legend = F) +
geom_line(data = df.mean, aes(x, mu), lwd = 1.5)
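In recent versions of tidyr (1.0.0 and later), `gather` is superseded by `pivot_longer`, so the reshaping step can also be written as below. Again this is a sketch rather than a drop-in replacement: `x.sim` and `post.y.sim` are hypothetical simulated stand-ins for `x` and `post.y`, so that the block runs on its own.

```r
library(dplyr)
library(tidyr)
library(ggplot2)

# Simulated stand-in for the posterior matrix (500 draws x 50 grid points);
# an assumption for illustration, not output from the model above.
set.seed(1)
x.sim <- seq(0, 30, length.out = 50)
post.y.sim <- matrix(rnorm(500 * length(x.sim)), nrow = 500)

# One row per (grid point, draw) pair
df.long <- post.y.sim %>%
  t() %>%
  as.data.frame() %>%            # columns V1..V500, one per draw
  mutate(x = x.sim) %>%
  pivot_longer(-x, names_to = 'rep', values_to = 'post.y')

# Mean across all draws at each grid point
df.mu <- df.long %>%
  group_by(x) %>%
  summarize(mu = mean(post.y))

# Choose the 10 draws up front so the subset is explicit and reusable
keep <- sample(unique(df.long$rep), 10)

p <- ggplot() +
  geom_line(data = filter(df.long, rep %in% keep),
            aes(x, post.y, group = rep), alpha = 0.25) +
  geom_line(data = df.mu, aes(x, mu), lwd = 1.5)
p
```

Sampling the draw labels into `keep` before filtering makes it easy to reuse the same subset across several plots.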