---
title: "The LKJ prior vs the Wishart prior"
date: "2016-03-11"
output: html_document
---
***(latest update : `r Sys.time()`)***
```{r setup, include=FALSE}
# knitr chunk defaults: show code, center all figures
knitr::opts_chunk$set(echo = TRUE, fig.align="center")
library(ggplot2)
# figure scale factor (not referenced in the visible chunks — presumably a template leftover)
fscale <- 1
```
As I noted at the end of [this article](http://stla.github.io/stlapblog/posts/MixedRepeatModel.html), JAGS returned an overestimate of the between-standard deviation $\sigma_{b_2}$.
This is how I simulated the data, with `I=3` (number of groups) and `J=4` (number of replicates per group):
```{r data}
simdata <- function(I, J){
  ## Simulate a two-timepoint repeated-measures dataset with I groups
  ## and J replicates per group and timepoint.
  ## Returns a data.frame with columns Timepoint, Group, Repeat, y.
  set.seed(444)  # fixed seed so the article is reproducible
  ### simulation of overall means ###
  Mu.t1 <- 20
  Mu.t2 <- 5
  Mu <- c(Mu.t1, Mu.t2)
  names(Mu) <- c("t1", "t2")
  # between-group covariance of the (t1, t2) group-mean pairs
  sigmab.t1 <- 8
  sigmab.t2 <- 1
  rho <- 0.2
  Sigma <- rbind(
    c(sigmab.t1^2, rho*sigmab.t1*sigmab.t2),
    c(rho*sigmab.t1*sigmab.t2, sigmab.t2^2)
  )
  # one (t1, t2) mean pair per group; columns are named after names(Mu)
  mu <- mvtnorm::rmvnorm(I, Mu, Sigma)
  ### simulation within groups ###
  sigmaw.t1 <- 2
  sigmaw.t2 <- 0.5
  y.t1 <- c(sapply(mu[,"t1"], function(m) rnorm(J, m, sigmaw.t1)))
  y.t2 <- c(sapply(mu[,"t2"], function(m) rnorm(J, m, sigmaw.t2)))
  ### construct the dataset ###
  Timepoint <- rep(c("t1", "t2"), each=I*J)
  Group <- paste0("grp", rep(gl(I,J), times=2))
  Repeat <- rep(1:J, times=2*I)
  dat <- data.frame(
    Timepoint=Timepoint,
    Group=Group,
    Repeat=Repeat,
    y=c(y.t1,y.t2)
  )
  # Coerce to factors explicitly. Under R >= 4.0 data.frame() keeps strings
  # as character, so relevel() would error here and nlevels(dat$Group) used
  # in later chunks would return 0. factor(levels=) also puts "t1" first,
  # which is what relevel(dat$Timepoint, "t1") achieved under older R.
  dat$Timepoint <- factor(dat$Timepoint, levels = c("t1", "t2"))
  dat$Group <- factor(dat$Group)
  return(dat)
}
```
Let us try JAGS on the data simulated with `I=100`:
```{r}
# large design: 100 groups with 4 replicates each
dat <- simdata(I=100, J=4)
```
First note that the `lme` estimates are quite good:
```{r nlme, message=FALSE}
library(nlme)
lme(y ~ Timepoint, data=dat, random= list(Group = pdSymm(~ 0+Timepoint )),
weights = varIdent(form = ~ Group:Timepoint | Timepoint) )
```
Now let us run JAGS (see the previous article for the code not shown here):
```{r jagsmodel, echo=FALSE}
# Recode the factors as integer indices: JAGS indexes arrays numerically
dat <- transform(dat, timepoint=as.integer(Timepoint), group=as.integer(Group))
jagsfile <- "JAGSmodel.txt"
# The model is written as an R function so R2WinBUGS::write.model() can
# deparse its body into BUGS/JAGS syntax; it is never evaluated as R code.
jagsmodel <- function(){
for(i in 1:ngroups){
# group means: bivariate normal around the overall means, precision Omega
mu[i,1:2] ~ dmnorm(Mu[1:2], Omega[1:2,1:2])
}
for(k in 1:n){
# observations: normal around their group/timepoint mean, within precision
y[k] ~ dnorm(mu[group[k], timepoint[k]], precw[timepoint[k]])
}
# Wishart prior on the between-group precision matrix (the prior under study)
Omega ~ dwish(Omega0, df0)
Mu[1] ~ dnorm(0, 0.001) # overall mean timepoint 1
Mu[2] ~ dnorm(0, 0.001) # overall mean timepoint 2
precw[1] ~ dgamma(1, 0.001) # inverse within variance timepoint 1
precw[2] ~ dgamma(1, 0.001) # inverse within variance timepoint 2
# derived quantities, monitored in coda.samples() below
sigmaw1 <- 1/sqrt(precw[1])
sigmaw2 <- 1/sqrt(precw[2])
Sigma <- inverse(Omega)
sigmab1 <- sqrt(Sigma[1,1])
sigmab2 <- sqrt(Sigma[2,2])
rhob <- Sigma[1,2]/(sigmab1*sigmab2)
}
R2WinBUGS::write.model(jagsmodel, jagsfile)
# Data list passed to JAGS; Omega0 and df0 parameterize the Wishart prior
jagsdata <- list(y=dat$y, ngroups=nlevels(dat$Group), n=length(dat$y),
timepoint=dat$timepoint, group=dat$group,
Omega0 = 100*diag(2), df0=2)
```
```{r JAGSinits, echo=FALSE}
estimates <- function(dat, perturb=FALSE){
  # Moment-based starting values for the JAGS chains, computed from the data.
  # With perturb=TRUE the responses are jittered first, so each chain gets
  # slightly different (but sensible) initial values.
  if (perturb) {
    dat$y <- dat$y + rnorm(length(dat$y), 0, 1)
  }
  # cell means: one row per group, one column per timepoint
  cell_means <- aggregate(y ~ timepoint:group, data = dat, FUN = mean)$y
  mu <- matrix(cell_means, ncol = 2, byrow = TRUE)
  # within precision per timepoint: mean of the reciprocal group variances
  prec_for <- function(tp) {
    v <- aggregate(y ~ Group, data = subset(dat, Timepoint == tp), FUN = var)$y
    mean(1 / v)
  }
  list(
    mu = mu,
    Mu = colMeans(mu),
    Omega = solve(cov(mu)),                 # empirical between precision matrix
    precw = c(prec_for("t1"), prec_for("t2"))
  )
}
# Chain 1 starts at the raw data-based estimates; chains 2 and 3 start
# at estimates computed from jittered copies of the data.
inits <- c(
  list(estimates(dat)),
  lapply(1:2, function(k) estimates(dat, perturb = TRUE))
)
```
```{r jagssamples, message=FALSE, cache=TRUE, collapse=TRUE}
library(rjags)
# Compile the model with one chain per initial-value list (three chains)
jagsmodel <- jags.model(jagsfile,
data = jagsdata,
inits = inits,
n.chains = length(inits))
update(jagsmodel, 5000) # warm-up
# Draw 10000 posterior samples per chain for the monitored quantities
jagssamples <- coda.samples(jagsmodel,
c("Mu", "sigmaw1", "sigmaw2", "sigmab1", "sigmab2", "rhob"),
n.iter= 10000)
```
Below are the summary statistics of the posterior samples:
```{r}
# posterior summaries across the three JAGS chains
summary(jagssamples)
```
Again, $\sigma_{b_2}$ is overestimated: its true value ($=1$) is less than the lower bound of the $95\%$-credible interval ($\approx 1.31$). The other estimates are quite good.
## Using the LKJ prior
The above problem is possibly due to the Wishart prior on the covariance matrix.
Stan allows one to use an [LKJ prior](http://stla.github.io/stlapblog/posts/StanLKJprior.html) on the correlation matrix. We will run it on the small dataset:
```{r}
# back to the small design of the original article (3 groups, 4 replicates)
dat <- simdata(I=3, J=4)
# integer indices for Stan
dat <- transform(dat, timepoint=as.integer(Timepoint), group=as.integer(Group))
```
```{r stanmodel, message=FALSE}
library(rstan)
rstan_options(auto_write = TRUE)  # cache the compiled model on disk
options(mc.cores = parallel::detectCores())  # run the chains in parallel
# Stan program. Fixes relative to the original:
#  - sigma_b and sigma_w are declared <lower=0>: with cauchy(0, 5) priors the
#    intended half-Cauchy requires the positivity constraint, otherwise the
#    sampler can propose negative standard deviations (rejections at best).
#  - `=` replaces the deprecated `<-` assignment (removed in Stan 2.33).
#  - `array[N] ...` replaces the old `real y[N]` array syntax (removed in 2.33).
stancode <- 'data {
  int<lower=1> N;                                 // number of observations
  array[N] real y;                                // observations
  int<lower=1> ngroups;                           // number of groups
  array[N] int<lower=1, upper=ngroups> group;     // group index per observation
  array[N] int<lower=1, upper=2> timepoint;       // timepoint index per observation
}
parameters {
  vector[2] Mu;                     // overall means, one per timepoint
  array[ngroups] vector[2] mu;      // group means
  cholesky_factor_corr[2] L;        // Cholesky factor of the 2x2 correlation
  vector<lower=0>[2] sigma_b;       // between-group SDs (constrained positive)
  vector<lower=0>[2] sigma_w;       // within-group SDs (constrained positive)
}
model {
  // half-Cauchy priors on the scales (half because of the lower=0 bounds)
  sigma_w ~ cauchy(0, 5);
  sigma_b ~ cauchy(0, 5);
  L ~ lkj_corr_cholesky(1);         // LKJ(1): uniform over correlation matrices
  Mu ~ normal(0, 25);
  for(j in 1:ngroups){
    mu[j] ~ multi_normal_cholesky(Mu, diag_pre_multiply(sigma_b, L));
  }
  for(k in 1:N){
    y[k] ~ normal(mu[group[k], timepoint[k]], sigma_w[timepoint[k]]);
  }
}
generated quantities {
  matrix[2,2] Omega;                // between-group correlation matrix
  matrix[2,2] Sigma;                // between-group covariance matrix
  real rho_b;                       // between-group correlation
  Omega = multiply_lower_tri_self_transpose(L);
  Sigma = quad_form_diag(Omega, sigma_b);
  rho_b = Sigma[1,2]/(sigma_b[1]*sigma_b[2]);
}'
### compile Stan model (slow the first time; cached by auto_write)
stanmodel <- stan_model(model_code = stancode, model_name="stanmodel")
### Stan data — names must match the declarations in the Stan data block
standata <- list(y=dat$y, N=nrow(dat), ngroups=nlevels(dat$Group),
timepoint=dat$timepoint, group=dat$group)
### Stan initial values
estimates <- function(dat, perturb=FALSE){
  # Moment-based starting values for the Stan chains (SD parameterization,
  # unlike the precision parameterization used for JAGS above).
  # perturb=TRUE jitters the responses so chains start at distinct values.
  if (perturb) {
    dat$y <- dat$y + rnorm(length(dat$y), 0, 1)
  }
  # cell means: one row per group, one column per timepoint
  mu <- matrix(
    aggregate(y ~ timepoint:group, data = dat, FUN = mean)$y,
    ncol = 2, byrow = TRUE
  )
  # within SD per timepoint: average of the per-group sample SDs
  within_sd <- function(tp) {
    mean(aggregate(y ~ Group, data = subset(dat, Timepoint == tp), FUN = sd)$y)
  }
  list(
    mu = mu,
    Mu = colMeans(mu),
    L = t(chol(cor(mu))),           # lower Cholesky of the empirical correlation
    sigma_b = sqrt(diag(var(mu))),  # empirical between-group SDs
    sigma_w = c(within_sd("t1"), within_sd("t2"))
  )
}
inits <- function(chain_id){
  # Chain 1 starts at the data-based estimates; every other chain starts
  # at estimates computed from a jittered copy of the data.
  estimates(dat, perturb = chain_id > 1)
}
```
We are ready to run the Stan sampler. Following some warning messages when I first ran it with the default values of the `control` argument, I increased `adapt_delta` and `max_treedepth`:
```{r stansampling, message=FALSE, warning=FALSE}
### run Stan
stansamples <- sampling(stanmodel, data = standata, init=inits,
iter = 15000, warmup = 5000, chains = 4,
control=list(adapt_delta=0.999, max_treedepth=15))
### outputs
library(coda)
codasamples <- do.call(mcmc.list,
plyr::alply(rstan::extract(stansamples, permuted=FALSE,
pars = c("Mu", "sigma_b", "sigma_w", "rho_b")),
2, mcmc))
summary(codasamples)
```
As compared to the [JAGS estimates](http://stla.github.io/stlapblog/posts/MixedRepeatModel.html) (given at the end), the estimates of $\sigma_{b_2}$ and $\rho_b$ obtained with Stan are really better. Note also that JAGS returned a huge credible interval for $\mu_2$.