library(BART)

# Dirichlet Prior
FUNCTION dprior
INPUT(alpha, F0, data)
  s<-DRAW n atoms from F0
  V<-DRAW n values from Beta(1, alpha)
  w[1]<-1
  w[i]<-V[i] * prod(1 - V[1:(i-1)]); i = 2,...,n   # stick-breaking weights
RETURN function(size = 10000){sample(s, size, prob = w, replace = TRUE)}

# Dirichlet Posterior
FUNCTION dpost
INPUT(alpha, F0, data)
  SET n as length of data
  SET F_n as empirical CDF(data)
  SET F_bar as n/(n+alpha) * F_n + alpha/(n+alpha) * F0
RETURN dprior(alpha+n, F_bar)


FOR i in 1,......,NSIM

    INPUT DRAW f0 from N(0, 3)
    INPUT DRAW F0 as a cdf
      # BART
      bart.tot<-BARTMODEL 
      # Compute the individual treatment effect (as specified in Hill, 2011)
      cate_mat<-bart.tot$yhat.train - bart.tot$yhat.test

        FOR j in 1,.....,NROW(cate_mat)
            #A row of cate_mat records Individual Treatment effects
            SET data as cate_mat[j,]
            Fpost<-dpost(5, F0, data)
            posterior<-Fpost()
            tau_post[[j]]<-mean(posterior)

      mntau[i]<-mean(unlist(tau_post))
      low95[i]<-HPD(as.mcmc(unlist(tau_post)), prob = 0.95)[1]
      high95[i]<-HPD(as.mcmc(unlist(tau_post)), prob = 0.95)[2]
  

###########################
# Compute 95% coverage
FUNCTION isCovered
    INPUT i
      ifelse(low95[i]<= ate_tau[[i]] & ate_tau[[i]] <= high95[i], 1, 0)
SET coverage[[i]] as isCovered(i); REPEAT for i in 1,...,length(ate_tau)

SET perCoverage as sum(unlist(coverage))/length(ate_tau)
OUTPUT perCoverage

# Interval width
width<-high95 - low95
OUTPUT mean(width)

# MSE 
mse[[i]]<-mean((ate_tau[[i]] - mntau[i])^2)
REPEAT length(ate_tau)
OUTPUT median(unlist(mse))
