---
title: "The LKJ prior vs the Wishart prior"
date: "2016-03-11"
output: html_document
---

***(latest update : `r Sys.time()`)***

```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE, fig.align = "center")
library(ggplot2)
fscale <- 1
```

As I noted at the end of [this article](http://stla.github.io/stlapblog/posts/MixedRepeatModel.html), JAGS returned an overestimate of the between-standard deviation $\sigma_{b_2}$.
This is how I simulated the data, with `I=3` (number of groups) and `J=4` (number of replicates per group):

```{r data}
# Simulates the two-timepoint dataset.
#   I    : number of groups
#   J    : number of replicates per group (and per timepoint)
#   seed : RNG seed set at the start of the simulation (default 444)
# Returns a data frame with columns Timepoint (factor, levels "t1", "t2"),
# Group (factor), Repeat (integer) and y (numeric).
simdata <- function(I, J, seed = 444){
  set.seed(seed)
  ### simulation of overall means ###
  Mu.t1 <- 20
  Mu.t2 <- 5
  Mu <- c(Mu.t1, Mu.t2)
  names(Mu) <- c("t1", "t2")
  sigmab.t1 <- 8
  sigmab.t2 <- 1
  rho <- 0.2
  Sigma <- rbind(
    c(sigmab.t1^2, rho*sigmab.t1*sigmab.t2),
    c(rho*sigmab.t1*sigmab.t2, sigmab.t2^2)
  )
  mu <- mvtnorm::rmvnorm(I, Mu, Sigma)
  ### simulation within-lots ###
  sigmaw.t1 <- 2
  sigmaw.t2 <- 0.5
  # vapply() pins the result type: J draws per group mean
  y.t1 <- c(vapply(mu[, "t1"], function(m) rnorm(J, m, sigmaw.t1), numeric(J)))
  y.t2 <- c(vapply(mu[, "t2"], function(m) rnorm(J, m, sigmaw.t2), numeric(J)))
  ### constructs the dataset ####
  # The factors are created explicitly: the code using this dataset calls
  # as.integer() and nlevels() on these columns, and data.frame() no longer
  # converts strings to factors by default (R >= 4.0).
  Timepoint <- factor(rep(c("t1", "t2"), each = I*J), levels = c("t1", "t2"))
  Group <- factor(paste0("grp", rep(gl(I, J), times = 2)))
  Repeat <- rep(seq_len(J), times = 2*I)
  dat <- data.frame(
    Timepoint = Timepoint,
    Group = Group,
    Repeat = Repeat,
    y = c(y.t1, y.t2)
  )
  return(dat)
}
```

Let us try JAGS on the data simulated with `I=100`:

```{r}
dat <- simdata(I=100, J=4)
```

First note that the `lme` estimates are quite good:

```{r nlme, message=FALSE}
library(nlme)
lme(y ~ Timepoint, data=dat, random= list(Group = pdSymm(~ 0+Timepoint )),
    weights = varIdent(form = ~ Group:Timepoint | Timepoint)
)
```

Now let us run JAGS (see the previous article for the code not shown here):

```{r jagsmodel, echo=FALSE}
# integer codes of the Timepoint and Group factors, used as indices in JAGS
dat <- transform(dat, timepoint=as.integer(Timepoint), group=as.integer(Group))
jagsfile <- "JAGSmodel.txt"
# The model is written as an R function whose body is BUGS code;
# R2WinBUGS::write.model() writes it to `jagsfile`.
jagsmodel <- function(){
  for(i in 1:ngroups){
    mu[i,1:2] ~ dmnorm(Mu[1:2], Omega[1:2,1:2])
  }
  for(k in 1:n){
    y[k] ~ dnorm(mu[group[k], timepoint[k]], precw[timepoint[k]])
  }
  Omega ~ dwish(Omega0, df0)
  Mu[1] ~ dnorm(0, 0.001) # overall mean timepoint 1
  Mu[2] ~ dnorm(0, 0.001) # overall mean timepoint 2
  precw[1] ~ dgamma(1, 0.001) # inverse within variance timepoint 1
  precw[2] ~ dgamma(1, 0.001) # inverse within variance timepoint 2
  sigmaw1 <- 1/sqrt(precw[1])
  sigmaw2 <- 1/sqrt(precw[2])
  Sigma <- inverse(Omega)
  sigmab1 <- sqrt(Sigma[1,1])
  sigmab2 <- sqrt(Sigma[2,2])
  rhob <- Sigma[1,2]/(sigmab1*sigmab2)
}
R2WinBUGS::write.model(jagsmodel, jagsfile)
jagsdata <- list(y=dat$y, ngroups=nlevels(dat$Group), n=length(dat$y),
                 timepoint=dat$timepoint, group=dat$group,
                 Omega0 = 100*diag(2), df0=2)
```

```{r JAGSinits, echo=FALSE}
# Initial values for JAGS, computed from empirical estimates.
# With perturb=TRUE, N(0,1) noise is added to y before computing them,
# so that different chains get different starting points.
estimates <- function(dat, perturb=FALSE){
  if(perturb) dat$y <- dat$y + rnorm(length(dat$y), 0, 1)
  # one row per group, one column per timepoint
  mu <- matrix(aggregate(y~timepoint:group, data=dat, FUN=mean)$y, ncol=2, byrow=TRUE)
  Mu <- colMeans(mu)
  Omega <- solve(cov(mu))
  # average of the inverse within-group sample variances, per timepoint
  precw1 <- mean(1/aggregate(y~Group, data=subset(dat, Timepoint=="t1"), FUN=var)$y)
  precw2 <- mean(1/aggregate(y~Group, data=subset(dat, Timepoint=="t2"), FUN=var)$y)
  precw <- c(precw1, precw2)
  return(list(mu=mu, Mu=Mu, Omega=Omega, precw=precw))
}
inits1 <- estimates(dat)
inits2 <- estimates(dat, perturb=TRUE)
inits3 <- estimates(dat, perturb=TRUE)
inits <- list(inits1,inits2,inits3)
```

```{r jagssamples, message=FALSE, cache=TRUE, collapse=TRUE}
library(rjags)
jagsmodel <- jags.model(jagsfile,
                        data = jagsdata,
                        inits = inits,
                        n.chains = length(inits))
update(jagsmodel, 5000) # warm-up
jagssamples <- coda.samples(jagsmodel,
                            c("Mu", "sigmaw1", "sigmaw2", "sigmab1", "sigmab2", "rhob"),
                            n.iter= 10000)
```

Below are the summary statistics of the posterior samples:

```{r}
summary(jagssamples)
```

Again, $\sigma_{b_2}$ is overestimated: its true value ($=1$) is less than the lower bound of the $95\%$-credible interval ($\approx 1.31$).
The other estimates are quite good.

## Using the LKJ prior

The above problem is possibly due to the Wishart prior on the covariance matrix.
Stan allows us to use an [LKJ prior](http://stla.github.io/stlapblog/posts/StanLKJprior.html) on the correlation matrix.
We will run it on the small dataset:

```{r}
dat <- simdata(I=3, J=4)
dat <- transform(dat, timepoint=as.integer(Timepoint), group=as.integer(Group))
```

```{r stanmodel, message=FALSE}
library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
# Note: the Stan program is a single-quoted R string, so it must not
# contain any single quote.
stancode <- 'data {
  int<lower=0> N; // number of observations
  real y[N]; // observations
  int<lower=1> ngroups; // number of groups
  int<lower=1, upper=ngroups> group[N]; // group indices
  int<lower=1, upper=2> timepoint[N]; // timepoint indices
}
parameters {
  vector[2] Mu;
  vector[2] mu[ngroups]; // group means
  cholesky_factor_corr[2] L;
  vector<lower=0>[2] sigma_b; // between standard deviations (positive)
  vector<lower=0>[2] sigma_w; // within standard deviations (positive)
}
model {
  sigma_w ~ cauchy(0, 5); // half-Cauchy, given the lower bound
  for(k in 1:N){
    y[k] ~ normal(mu[group[k], timepoint[k]], sigma_w[timepoint[k]]);
  }
  sigma_b ~ cauchy(0, 5); // half-Cauchy, given the lower bound
  L ~ lkj_corr_cholesky(1);
  Mu ~ normal(0, 25);
  for(j in 1:ngroups){
    mu[j] ~ multi_normal_cholesky(Mu, diag_pre_multiply(sigma_b, L));
  }
}
generated quantities {
  matrix[2,2] Omega;
  matrix[2,2] Sigma;
  real rho_b;
  Omega = multiply_lower_tri_self_transpose(L);
  Sigma = quad_form_diag(Omega, sigma_b);
  rho_b = Sigma[1,2]/(sigma_b[1]*sigma_b[2]);
}'

### compile Stan model
stanmodel <- stan_model(model_code = stancode, model_name="stanmodel")

### Stan data
standata <- list(y=dat$y, N=nrow(dat), ngroups=nlevels(dat$Group),
                 timepoint=dat$timepoint, group=dat$group)

### Stan initial values
# Empirical estimates used as initial values.
# With perturb=TRUE, N(0,1) noise is added to y before computing them.
estimates <- function(dat, perturb=FALSE){
  if(perturb) dat$y <- dat$y + rnorm(length(dat$y), 0, 1)
  # one row per group, one column per timepoint
  mu <- matrix(aggregate(y~timepoint:group, data=dat, FUN=mean)$y, ncol=2, byrow=TRUE)
  Mu <- colMeans(mu)
  sigma_b <- sqrt(diag(var(mu)))
  # chol() returns the upper factor; transpose to get the lower one
  L <- t(chol(cor(mu)))
  # average of the within-group sample standard deviations, per timepoint
  sigmaw1 <- mean(aggregate(y~Group, data=subset(dat, Timepoint=="t1"), FUN=sd)$y)
  sigmaw2 <- mean(aggregate(y~Group, data=subset(dat, Timepoint=="t2"), FUN=sd)$y)
  return(list(mu=mu, Mu=Mu, L=L, sigma_b=sigma_b, sigma_w = c(sigmaw1, sigmaw2)))
}
# Initial values for one chain: unperturbed estimates for the first chain,
# perturbed estimates for the others.
inits <- function(chain_id){
  values <- estimates(dat, perturb = chain_id > 1)
  return(values)
}
```

We are ready to run the Stan sampler.
Following some messages when I first ran it with the default values of the `control` argument, I increase `adapt_delta` and `max_treedepth`:

```{r stansampling, message=FALSE, warning=FALSE}
### run Stan
stansamples <- sampling(stanmodel, data = standata, init=inits,
                        iter = 15000, warmup = 5000, chains = 4,
                        control=list(adapt_delta=0.999, max_treedepth=15))

### outputs
library(coda)
# one mcmc object per chain (second dimension of the extracted array)
codasamples <- do.call(mcmc.list,
                       plyr::alply(rstan::extract(stansamples, permuted=FALSE,
                                                  pars = c("Mu", "sigma_b", "sigma_w", "rho_b")),
                                   2, mcmc))
summary(codasamples)
```

As compared to the [JAGS estimates](http://stla.github.io/stlapblog/posts/MixedRepeatModel.html) (given at the end), the estimates of $\sigma_{b_2}$ and $\rho_b$ obtained with Stan are much better.
Note also that JAGS returned a huge credible interval for $\mu_2$.