#### Introduction ####
# 1. This file contains the code to reproduce the results of the data analysis in the paper
# "A Non-parametric Direct Learning Approach to Heterogeneous Treatment Effect Estimation under Unmeasured Confounding".

# 2. The authors ran this code on a MacBook with a 2.9 GHz quad-core Intel Core i7 processor and 16 GB of memory.
# Running the entire program, including generating the figures, took approximately 8 hours.

# 3. The code sections in this file have to be executed in the original order.

#### Packages ####
library(MASS)
library(glmnet)
library(pdist)
library(tidyverse)
library(grf)
library(rpart)
library(caret)
library(locfit)
library(haven)
library(rpart)
library(rpart.plot)
library(mgcv)
library(plotly)
library(gridExtra)

#### IV-DL functions ####

augments <- function(Y, A, Z, Dhat, pi_Zhat, cate_pre = (Yp-Yn)/Dhat,
                     Yp, Yn, Ap = 1, An, Ym = (Yp+Yn)/2, Am = (Ap+An)/2) {
  # Build the augmentation terms used by the IV-DL estimator.
  #
  # `cate_pre` is a preliminary CATE estimate; by default it is the
  # Wald-type ratio (Yp - Yn) / Dhat. Ym and Am default to the midpoints
  # (Yp + Yn) / 2 and (Ap + An) / 2. All arguments may be scalars or
  # vectors; ordinary R recycling applies.
  #
  # Returns a list with
  #   g    (Yp + Yn) / 2
  #   h1   Yp + (A - Ap - Z * Dhat) * cate_pre / 2
  #   h2   Yn + (A - An - Z * Dhat) * cate_pre / 2
  #   h3   Ym + (A - Am - Z * Dhat) * cate_pre / 2
  #   Ymr  Z * (Y - Yn - cate_pre * (A - An) / 2) / Dhat / pi_Zhat + (Yp - Yn) / Dhat
  # together with the Yp, Yn, Ap, An that were used.

  # Every h-term shares the same "A shifted by Z * Dhat" adjustment.
  z_shift <- Z * Dhat
  shifted <- function(base, a_ref) base + (A - a_ref - z_shift) * cate_pre / 2

  g  <- (Yp + Yn) / 2
  h1 <- shifted(Yp, Ap)
  h2 <- shifted(Yn, An)
  h3 <- shifted(Ym, Am)

  # Residual against the "n" arm, scaled by 1 / (Dhat * pi_Zhat), plus the
  # plug-in contrast (Yp - Yn) / Dhat.
  resid_n <- Y - Yn - cate_pre * (A - An) / 2
  Ymr <- Z * resid_n / Dhat / pi_Zhat + (Yp - Yn) / Dhat

  list(g = g, h1 = h1, h2 = h2, h3 = h3, Ymr = Ymr,
       Yp = Yp, Yn = Yn, Ap = Ap, An = An)
}

IVDL.linear <- function(Y,X,Z,pi_Z=1/2,d,g,lambda_grid = 5^(-10:10)){
  # Linear IV-DL fit: penalised gaussian glmnet regression of the
  # pseudo-outcome 2 * Z * (Y - g) / d on X, with observation weights 1 / pi_Z.
  #
  # Arguments
  #   Y            numeric outcome, length nrow(X)
  #   X            covariate matrix (passed to glmnet unstandardised)
  #   Z            instrument, length nrow(X)
  #   pi_Z         scalar or length-nrow(X) vector; the weights are 1 / pi_Z
  #   d, g         scalars or vectors plugged into the pseudo-outcome
  #   lambda_grid  penalty values. With two or more values the penalty is
  #                chosen by 10-fold cross-validation (lambda.min). With a
  #                single value (e.g. 0 for an unpenalised fit) glmnet is
  #                fitted at that value directly, because cv.glmnet requires
  #                at least two lambdas.
  # Value
  #   list(coef = named vector ("(Intercept)", one slope per column of X),
  #        opt.lambda = the penalty the coefficients correspond to)

  # Check arguments
  n <- nrow(X)
  stopifnot(is.numeric(Y))
  if (length(Y) != n) stop("'Y' has different length with number of rows in X")
  if (length(Z) != n) stop("'Z' has different length with number of rows in X")
  if (!(length(pi_Z) %in% c(1, n)))
    stop("'pi_Z' must have length 1 or the number of rows in X")

  response <- 2*Z*(Y-g)/d
  # glmnet requires exactly one weight per observation, so recycle a scalar pi_Z.
  w <- rep_len(1/pi_Z, n)

  if (length(lambda_grid) == 1) {
    # A single penalty leaves nothing to tune: fit glmnet at that value.
    fit <- glmnet(x = X, y = response, weights = w, lambda = lambda_grid,
                  family = "gaussian", standardize = FALSE, intercept = TRUE)
    coef <- c(fit$a0[1], fit$beta[, 1])
    names(coef)[1] <- "(Intercept)"
    return(list(coef = coef, opt.lambda = fit$lambda[1]))
  }

  # cv fit
  fit <- cv.glmnet(x = X, y = response, weights = w, lambda = lambda_grid,
                   nfolds = 10, family = "gaussian",
                   standardize = FALSE, intercept = TRUE)

  # glmnet stores a user-supplied lambda sequence sorted in DECREASING order,
  # so positions refer to fit$lambda, not to lambda_grid.
  opt <- match(fit$lambda.min, fit$lambda)
  if (opt == 1 || opt == length(fit$lambda))
    warning("The optimal lambda for fitting treatment effect may fall outside the window")
  coef <- c(fit$glmnet.fit$a0[opt], fit$glmnet.fit$beta[, opt])
  names(coef)[1] <- "(Intercept)"
  list(coef = coef, opt.lambda = fit$lambda[opt])
}

IVDL.kernel <- function(Y,X,Z,pi_Z=1/2,d,g,
                        qtile_grid = c(0.1,0.25,0.5,0.75,0.9),
                        lambda_grid = 10^(-5:5),
                        n_folds = 5){
  # Kernel IV-DL fit: weighted Gaussian-kernel ridge regression (WLS_ker) of
  # the pseudo-outcome 2 * Z * (Y - g) / d on X, with the bandwidth quantile
  # and penalty tuned by n_folds-fold cross-validation inside WLS_ker.
  #
  # pi_Z may be a scalar or one value per row of X; a scalar is recycled so
  # that the per-fold subsetting inside WLS_ker keeps one weight per row.
  #
  # Value: list(coef = kernel-expansion coefficients (intercept first),
  #             opt.qtile, opt.lambda) as selected by WLS_ker.

  # Check arguments
  n <- nrow(X)
  stopifnot(is.numeric(Y))
  if (length(Y) != n) stop("'Y' has different length with number of rows in X")
  if (length(Z) != n) stop("'Z' has different length with number of rows in X")
  if (!(length(pi_Z) %in% c(1, n)))
    stop("'pi_Z' must have length 1 or the number of rows in X")

  #cv KRR
  response <- 2*Z*(Y-g)/d

  # NOTE(review): pi_Z itself is used as the observation weight here, whereas
  # IVDL.linear weights by 1 / pi_Z -- confirm which is intended.
  fit <- WLS_ker(response, X, rep_len(pi_Z, n), qtile_grid, lambda_grid, n_folds)
  list(coef = fit$coef,
       opt.qtile = fit$opt.qtile,
       opt.lambda = fit$opt.lambda)
}

WLS_ker <- function(Y, X, weights = NULL,
                    qtile_grid = c(0.1,0.25,0.5,0.75,0.9),
                    lambda_grid = 10^(-5:5),
                    n_folds = 5){
  # Weighted kernel ridge regression with a Gaussian kernel and an
  # unpenalised intercept: Y ~ b0 + K %*% alpha with penalty alpha' K alpha
  # (fitted by ridgereg on the design cbind(1, K)).
  #
  # Arguments
  #   Y            numeric response
  #   X            covariates, one row per observation
  #   weights      NULL (equal weights), a scalar (recycled), or one weight
  #                per observation
  #   qtile_grid   candidate distance quantiles that set the kernel bandwidth
  #                (genKernel's default epsilon = 1 / (2 * quantile^2))
  #   lambda_grid  candidate ridge penalties
  #   n_folds      number of CV folds; if <= 1 no tuning is done and the fit
  #                uses qtile = 0.5 with lambda = 0
  # Value
  #   list(coef, coef1se, opt.qtile, opt.lambda, opt.lambda1se). coef and
  #   coef1se have length 1 + length(Y) (intercept first). They are the fits at
  #   opt.lambda and opt.lambda1se.

  #size of sample and tuning grids
  n <- length(Y)
  n_q <- length(qtile_grid)
  n_l <- length(lambda_grid)
  X <- as.matrix(X)
  # Recycle a scalar weight so that weights[-ind] keeps one weight per row.
  if (length(weights) == 1) weights <- rep(weights, n)

  if (n_folds > 1) {
    # Column layout of MSE: qtile-major, lambda varying fastest.
    MSE <- matrix(nrow = n_folds, ncol = n_l*n_q)
    folds <- createFolds(1:n, k = n_folds)

    for (i in seq_len(n_folds)) {
      ind <- folds[[i]]
      X_tr <- X[-ind, , drop = FALSE]
      X_te <- X[ind, , drop = FALSE]
      for (j in seq_len(n_q)) {
        ker <- genKernel(X_tr, kernel = 'gaussian', qtile = qtile_grid[j])
        # Evaluate the test kernel with the bandwidth fitted on the training
        # fold; otherwise epsilon would be recomputed from test-train distances
        # and would not match the fitted coefficients.
        ker_te <- genKernel(X_tr, X_te, kernel = 'gaussian',
                            epsilon = attr(ker, "epsilon"))
        K <- cbind(1, ker)
        G <- rbind(0, cbind(0, ker))
        fit <- ridgereg(K, Y[-ind], G, weights[-ind], lambda_grid)
        Yhat_te <- cbind(1, ker_te) %*% fit$coef
        MSE[i, ((j-1)*n_l+1):(j*n_l)] <-
          colMeans((sweep(x = Yhat_te, MARGIN = 1, STATS = Y[ind], FUN = "-"))^2)
      }
    }
    mMSE <- as.numeric(colMeans(MSE))
    seMSE <- apply(MSE, 2, sd)/sqrt(n_folds)
    best <- which.min(mMSE)

    # NOTE(review): the tolerance is HALF a standard error (IVDL.local uses a
    # full one) -- confirm this is intended.
    onesebound <- mMSE[best] + seMSE[best]/2
    expand_grid_l <- rep(lambda_grid, times = n_q)
    expand_grid_q <- rep(qtile_grid, each = n_l)
    opt.qtile <- expand_grid_q[best]
    opt.lambda <- expand_grid_l[best]

    # At the chosen qtile, take the last lambda (in grid order) whose CV error
    # is within the bound. `<=` guarantees the minimiser itself qualifies.
    same_q <- ((best - 1) %/% n_l) * n_l + seq_len(n_l)
    opt.lambda1se <- lambda_grid[max(which(mMSE[same_q] <= onesebound))]

    ker <- genKernel(X, kernel = 'gaussian', qtile = opt.qtile)
    K <- cbind(1, ker)
    G <- rbind(0, cbind(0, ker))
    fit <- ridgereg(K, Y, G, weights, opt.lambda)
    fit1se <- ridgereg(K, Y, G, weights, opt.lambda1se)
  } else {
    # No tuning: median-distance bandwidth and no penalty.
    opt.qtile <- 0.5
    opt.lambda <- 0
    opt.lambda1se <- 0
    ker <- genKernel(X, kernel = 'gaussian', qtile = opt.qtile)
    K <- cbind(1, ker)
    G <- rbind(0, cbind(0, ker))
    fit <- ridgereg(K, Y, G, weights, 0)
    fit1se <- fit
  }

  list(coef = as.vector(fit$coef),
       coef1se = as.vector(fit1se$coef),
       opt.qtile = opt.qtile,
       opt.lambda = opt.lambda,
       opt.lambda1se = opt.lambda1se)
}

genKernel <- function(x, y, kernel = c("gaussian", "polynomial"),
                      epsilon = 1/2/quantile(d, qtile)^2, degree = 2, qtile = 0.5) {
  # Kernel matrix between the rows of `y` (result rows) and the rows of `x`
  # (result columns); `y` defaults to `x`.
  #
  #   "gaussian":   exp(-epsilon * dist(y_i, x_j)^2). The default epsilon is
  #                 1 / (2 * q^2), where q is the `qtile` quantile of all the
  #                 pairwise distances.
  #                 NOTE: that default is a lazily evaluated promise that
  #                 reads the local variable `d` created below, so `d` must
  #                 exist (under that name) before `epsilon` is first used.
  #   "polynomial": (1 + y %*% t(x))^degree.
  #
  # The result carries a "type" attribute plus "epsilon" or "degree".
  kernel <- match.arg(kernel)
  x <- as.matrix(x)
  y <- if (missing(y)) x else as.matrix(y)

  ker <- switch(kernel,
    gaussian = {
      d <- as.matrix(suppressWarnings(pdist(y, x)))
      structure(exp(-epsilon * d^2), epsilon = epsilon)
    },
    polynomial = structure((1 + y %*% t(x))^degree, degree = degree)
  )

  attr(ker, "type") <- kernel
  ker
}


ridgereg <- function(X, y, P, weights = NULL, lambda = 10^(-5:5)) {

  ## Generalised weighted ridge regression over a grid of penalties.
  ## For each lambda, minimise
  ##   sum(w * (y - X %*% beta)^2) / n + lambda * t(beta) %*% P %*% beta
  ## by solving the normal equations (X'WX + n * lambda * P) beta = X'W y with
  ## the Moore-Penrose pseudo-inverse (MASS::ginv), so a singular system does
  ## not error.
  ##
  ## X: design matrix (n x p)
  ## y: response (length n)
  ## P: penalty matrix (p x p)
  ## weights: NULL (all ones), a scalar (recycled) or one weight per row
  ## lambda: vector of penalties (length N)
  ##
  ## Returns list(coef = p x N coefficients, fitted = n x N fitted values,
  ##              gcv = GCV score per lambda, lambda).

  n <- length(y)
  N <- length(lambda)
  nlambda <- n * lambda
  if (length(weights) == 0) weights <- rep(1, n)
  # A scalar weight must be expanded: diag() of a scalar builds an identity
  # matrix of that SIZE (e.g. diag(0.5) is 0 x 0), not a 1-element diagonal.
  if (length(weights) == 1) weights <- rep(weights, n)
  if (length(weights) != n) stop("'weights' must have length 1 or length(y)")

  # X'W, formed by scaling the rows of X instead of building an n x n diag(w).
  B <- t(X * weights)
  A <- B %*% X

  f <- matrix(0, n, N)
  beta <- matrix(0, ncol(X), N)
  gcv <- numeric(N)

  for (i in seq_len(N)) {

    # Solve normal equation
    U <- ginv(A + nlambda[i] * P) %*% B
    beta[, i] <- U %*% y
    H <- X %*% U
    f[, i] <- H %*% y

    # GCV score: n * weighted RSS / (n - trace(H))^2
    rss <- sum(weights * (y - f[, i])^2)
    gcv[i] <- n * rss/(n - sum(diag(H)))^2

  }

  list(coef = beta, fitted = f, gcv = gcv, lambda = lambda)
}


IVDL.local <- function(Y,X, X_test, Z, pi_Z=1/2, d, g,
                       lambda_grid = 0,
                       qtile_grid = c(0.1,0.5,0.75,0.9, 1.5, 3, 10),
                       test_weights = NULL, n_folds = 5){
  # Local-linear IV-DL: kernel-weighted local regression (LocalReg) of the
  # pseudo-outcome 2 * Z * (Y - g) / d on X, evaluated at the rows of X_test.
  #
  # Tuning (n_folds-fold CV): take the (qtile, lambda) pair with the smallest
  # mean CV error. Then, at that qtile, take the largest lambda whose mean CV
  # error is within one standard error of the minimum.
  #
  # pi_Z may be a scalar or one value per row of X (a scalar is recycled).
  # test_weights, if supplied, is passed to LocalReg as the weight matrix for
  # X_test (one row per training row, one column per row of X_test).
  #
  # Value: list(coef, pred, test_weights) from the fit at X_test, plus
  # fitted = predictions at the training rows X.

  #size of sample and tuning grids
  n <- length(Y)
  n_q <- length(qtile_grid)
  n_l <- length(lambda_grid)
  if (!(length(pi_Z) %in% c(1, n))) stop("'pi_Z' must have length 1 or length(Y)")
  # Recycle a scalar pi_Z so that pi_Z[-ind] keeps one value per training row.
  pi_Z <- rep_len(pi_Z, n)

  # pseudo-outcome
  response <- 2*Z*(Y-g)/d

  #tuning: MSE columns follow LocalReg's pred layout (qtile-major, lambda fastest)
  MSE <- matrix(nrow = n_folds, ncol = n_q*n_l)
  folds <- createFolds(1:n, k = n_folds)
  for (i in seq_len(n_folds)) {
    ind <- folds[[i]]
    fit <- LocalReg(response[-ind], X[-ind, , drop = FALSE], X[ind, , drop = FALSE],
                    pi_Z[-ind], lambda_grid = lambda_grid, weights = NULL,
                    qtile_grid = qtile_grid)
    SE <- (sweep(fit$pred, 1, response[ind], FUN = "-"))^2
    MSE[i, ] <- colMeans(SE)
  }

  mMSE <- as.numeric(colMeans(MSE))
  seMSE <- apply(MSE, 2, sd)/sqrt(n_folds)
  best <- which.min(mMSE)
  onesebound <- mMSE[best] + seMSE[best]
  opt.ind.qtile <- (best - 1) %/% n_l + 1
  mMSE_opt.qtile <- mMSE[(opt.ind.qtile - 1)*n_l + seq_len(n_l)]
  opt.ind.lambda <- which(mMSE_opt.qtile <= onesebound)
  opt.lambda <- if (n_l == 1) lambda_grid else max(lambda_grid[opt.ind.lambda])
  opt.qtile <- qtile_grid[opt.ind.qtile]

  fit_test <- LocalReg(response, X, X_test, pi_Z, lambda_grid = opt.lambda,
                       weights = test_weights, qtile_grid = opt.qtile)
  fit <- LocalReg(response, X, X, pi_Z, lambda_grid = opt.lambda,
                  weights = NULL, qtile_grid = opt.qtile)

  list(coef = fit_test$coef, pred = fit_test$pred,
       test_weights = fit_test$weights, fitted = fit$pred)
}


LocalReg <- function(Y, X_train, X_new, pi_Z, lambda_grid = 10^(-5:5),
                     weights = NULL, qtile_grid = c(0.1,0.5,0.75,0.9, 1.5, 3, 10), epsilon = NULL){
  # Kernel-weighted local linear regression of Y on X_train, evaluated at each
  # row of X_new.
  #
  # For an evaluation row x0 with weight vector w over the training rows:
  #   beta(x0) = ginv(X' W X + lambda * I) X' W Y,   X = cbind(1, X_train)
  # (the ridge term also penalises the intercept), and the prediction is
  # c(1, x0) %*% beta(x0).
  #
  # Arguments
  #   Y            response for the training rows
  #   X_train      training covariate matrix
  #   X_new        evaluation points, one per row
  #   pi_Z         per-training-row factor MULTIPLIED into the kernel weights
  #                (NOTE(review): IVDL.linear weights by 1 / pi_Z -- confirm)
  #   lambda_grid  ridge penalties to evaluate
  #   weights      NULL to build Gaussian kernel weights over qtile_grid, or a
  #                user weight matrix with one row per training row and one
  #                column per row of X_new (qtile_grid is then not used)
  #   qtile_grid   bandwidth grid: q <= 1 uses the q-quantile of the
  #                distances, q > 1 uses q * max distance; epsilon = 1/(2*s^2)
  #   epsilon      currently unused; kept for interface compatibility
  # Value
  #   list(coef = local coefficients (one column per row of X_new) for the LAST
  #               (qtile, lambda) evaluated,
  #        pred = predictions, one column per (qtile, lambda) pair, qtile-major
  #               with lambda varying fastest,
  #        weights = the weight matrix last used, lambda_grid,
  #        epsilon_grid = bandwidths in qtile_grid order, or NULL when
  #               user weights were supplied)

  #size of sample and tuning grids
  n <- length(Y)
  n_q <- length(qtile_grid)
  n_l <- length(lambda_grid)

  X <- cbind(1, X_train)
  p <- ncol(X)
  X_new1 <- cbind(1, X_new)

  # Local WLS coefficients for every column of W, computed with crossprod
  # instead of materialising an n x n diag(w) for each evaluation point.
  local_beta <- function(W, lambda) {
    W %>% as_tibble %>%
      reframe(across(.cols = everything(),
                     .fns = ~ginv(crossprod(X, X * .x) + diag(lambda, p)) %*%
                       crossprod(X, .x * Y)))
  }

  if (is.null(weights)) {
    d <- as.matrix(suppressWarnings(pdist(as.matrix(X_new1), as.matrix(X))))
    # One bandwidth per qtile_grid entry, kept in qtile_grid order so that
    # pred columns line up with qtile_grid even when the grid is unsorted.
    epsilon_grid <- vapply(qtile_grid, function(q) {
      s <- if (q <= 1) quantile(d, q, names = FALSE) else max(d) * q
      1/2/s^2
    }, numeric(1))
    pred <- matrix(nrow = nrow(X_new1), ncol = n_l*n_q)
    for (j in seq_len(n_q)) {
      ker <- exp(-epsilon_grid[j] * d^2)
      weights <- sweep(t(ker), 1, pi_Z, FUN = "*")
      for (i in seq_len(n_l)) {
        beta <- local_beta(weights, lambda_grid[i])
        pred[, n_l*(j-1)+i] <- colSums(t(X_new1) * as.matrix(beta))
      }
    }

  } else {
    # User-supplied weights: no bandwidth grid is involved.
    epsilon_grid <- NULL
    pred <- matrix(nrow = nrow(X_new1), ncol = n_l)
    for (i in seq_len(n_l)) {
      beta <- local_beta(weights, lambda_grid[i])
      pred[, i] <- colSums(t(X_new1) * as.matrix(beta))
    }
  }

  list(coef = beta, pred = pred,
       weights = weights,
       lambda_grid = lambda_grid,
       epsilon_grid = epsilon_grid)
}







##### Import Data #####
load("~/data80.RData")
#### CATE Estimation ####
n <- nrow(data80)
Y <- pull(data80, WEEKSM)
X <- as.matrix(data80[,4:9])
A <- pull(data80, KID3)
Z <- pull(data80, SEXKID)
## Estimation of nuisance
# pi_Zhat: Z is treated as coded +/-1 throughout, so (E[Z | X] + 1) / 2 is the
# estimated probability that Z = +1.
# NOTE(review): pi_Zhat is used below as the propensity for every unit,
# including units with Z = -1 (where 1 - pi_Zhat might be expected) -- confirm.
reg_forest_Z.X <- regression_forest(X,Z,tune.parameters = "all")
Zhat <- reg_forest_Z.X$predictions[,1]
pi_Zhat <- (Zhat+1)/2
rm("reg_forest_Z.X")

# Conditional means of Y and A given (X, Z). In this script the suffix "p"
# marks predictions at Z = +1 and "m" marks predictions at Z = -1.
XZ <- data.frame(X,Z)
reg_forest_Y.XZ <- regression_forest(XZ,Y,tune.parameters = "all")
reg_forest_A.XZ <- regression_forest(XZ,A,tune.parameters = "all")
Yp <- predict(reg_forest_Y.XZ, newdata = data.frame(X,Z=1))$predictions
Ym <- predict(reg_forest_Y.XZ, newdata = data.frame(X,Z=-1))$predictions
Am <- predict(reg_forest_A.XZ, newdata = data.frame(X,Z=-1))$predictions
rm("reg_forest_Y.XZ")
rm("reg_forest_A.XZ")

#Estimation of delta(x): CATE of Z on A (both recoded from +/-1 to 0/1)
cau_forest_delta <- causal_forest(X,(A+1)/2,(Z+1)/2,tune.parameters = "all")
Dhat_CF <- cau_forest_delta$predictions

#augments
# augments() names the Z = -1 arm "n" (Yn, An) and reserves its Ym / Am
# arguments for midpoints, so this script's Z = -1 predictions go in as Yn / An.
aug_CF <- augments(Y, A, Z, Dhat = Dhat_CF, pi_Zhat, Yp = Yp, Yn = Ym, An = Am)

###CATE (A on Y) estimation using IVDL (g = 0, single penalty lambda = 0)
fit_IVDL_DCF <- IVDL.linear(Y,X,Z,pi_Zhat,Dhat_CF, g=0,lambda_grid = 0)
cate_tr_IVDL_DCF <- cbind(1, X)%*%fit_IVDL_DCF$coef


#predictions
# NOTE(review): this step used to be predict(fit_IVDLCF_DCF, estimate.variance = F),
# but no object `fit_IVDLCF_DCF` is created anywhere in this file. The IV-DL
# estimates computed just above are used instead, stored in the `predictions`
# column the figures expect -- confirm this is the intended estimate.
pred <- tibble(predictions = as.vector(cate_tr_IVDL_DCF))

# Combine predictions with the covariates and rescale for plotting:
# RACEM merges the two race dummies, INCOMED is divided by 1000, and the
# AGEQ* columns are divided by 4 (presumably quarters -> years).
res <- pred %>%
  bind_cols(as_tibble(X)) %>%
  mutate(RACEM = RACEM1 + 2*RACEM2,
         INCOMED = INCOMED/1000,
         .after = GRADEM,
         .keep = "unused") %>%
  mutate(AGEMFB = AGEQMFB/4,
         AGEM = AGEQM/4,
         .before = GRADEM,
         .keep = "unused")

#### Figure 2 ####
# Subgroup identification: grow a regression tree on the estimated CATE.
# Variable importance for tree splitting was used to pick the three
# covariates below; the leaves of the resulting tree define the subgroups.
# rpart's internal cross-validation is random, hence the fixed seed.
set.seed(100)
modtreeV3 <- rpart(predictions ~ AGEM + AGEMFB + INCOMED,
                   data = res, method = "anova",
                   control = rpart.control(cp = 0.022))
rpart.plot(modtreeV3)

#### Figure 3 ####
# Histogram of the estimated CATE within three leaves of the Figure 2 tree.
# Per rpart's documentation, modtreeV3$where gives, for each observation, the
# row of modtreeV3$frame of its leaf. Rows 5, 8 and 9 (specific to this
# fitted tree) are labelled as subgroups 1, 4 and 5 in the legend.
# Each observation is counted once. A former pivot_longer over
# AGEM/AGEMFB/INCOMED replicated every row three times and tripled the counts.
res %>%
  mutate(leafnode = factor(modtreeV3$where)) %>%
  filter(leafnode %in% c("5", "8", "9")) %>%
  droplevels() %>%
  ggplot(aes(x = predictions, fill = leafnode)) +
  geom_histogram(bins = 60, alpha = 0.7, position = "identity") +
  scale_fill_discrete(name = "Subgroups", labels = c("1", "4", "5")) +
  labs(x = "Estimated CATE")

#### Figure 4 ####

# For a readable 3D plot, show a random subset of (at most) 3000 fitted
# points. The seed makes the subset reproducible; slice_sample truncates to
# the available rows if there are fewer than 3000.
set.seed(100)
plot_ly(slice_sample(res, n = 3000), x = ~AGEM, y = ~AGEMFB, z = ~INCOMED,
        marker = list(size = 3, opacity = 0.7,
                      color = ~predictions,
                      # colorscale = "Reds",
                      showscale = TRUE,
                      colorbar = list(title = list(text = 'Estimated CATE'),
                                      len = 0.5,
                                      x= 0.8))) %>%
  add_markers() %>%
  layout(scene = list(zaxis = list(title = 'Income of Dad'),
                      xaxis = list(title = list(text='Age of Mom')),
                      yaxis = list(title = list(text='Age of Mom at First Birth')))
  )

