Estimating Conditional Average Treatment Effect using HAL

Author

Sky Qiu

Introduction

In this tutorial, we demonstrate how to estimate the conditional average treatment effect (CATE) function using the highly adaptive lasso (HAL).

Problem setup

We consider the following setup. The observed data consist of \(n\) i.i.d. copies of \[ O=(W,A,Y)\sim P_0\in\mathcal{M}, \] where

  • \(W \in \mathbb{R}^d\) are baseline covariates,

  • \(A \in \{0,1\}\) is a binary treatment indicator, and

  • \(Y \in \mathbb{R}\) is an outcome of interest.

The statistical model \(\mathcal{M}\) is nonparametric.

Our parameter of interest is the CATE function: \[ \tau_0(W)=E_0(Y\mid A=1,W)-E_0(Y\mid A=0,W), \] which represents the expected effect of the treatment for an individual with covariates \(W\).

In this example, we learn the CATE with the R-loss \(L_{m_0,g_0}(\tau)\) of Nie and Wager (2021). The loss is indexed by two nuisance parameters \(m\) and \(g\): \[ L_{m,g}(\tau)=\left( Y-m(W)-(A-g(1\mid W))\tau\right)^2, \] where \[ m(W)=E(Y\mid W)\quad\text{and}\quad g(1\mid W)=P(A=1\mid W). \] This loss is robust to errors in the nuisance parameters. If \(g=g_0\), then \(\tau_0=\arg\min_{\tau}P_0 L_{m,g}(\tau)\) for any \(m\). More generally, errors in \(m\) and \(g\) shift the minimizer only through the second-order terms \((m-m_0)(g-g_0)\) and \((g-g_0)^2\).
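Minimizing the conditional risk pointwise in \(W\) makes explicit how the nuisance parameters enter the minimizer. Abbreviate \(g=g(1\mid W)\) and \(g_0=g_0(1\mid W)\). Note that \(Y-m_0(W)=(A-g_0)\tau_0(W)+\varepsilon\) with \(\varepsilon=Y-E_0(Y\mid A,W)\), so \(E_0(\varepsilon\mid A,W)=0\). Setting the derivative of \(E_0[L_{m,g}(\tau)\mid W]\) with respect to \(\tau\) to zero gives

\[ \begin{aligned} \tau_{m,g}(W) &= \frac{E_0[(A-g)(Y-m(W))\mid W]}{E_0[(A-g)^2\mid W]} = \frac{g_0(1-g_0)\,\tau_0(W)+(m_0(W)-m(W))(g_0-g)}{g_0(1-g_0)+(g_0-g)^2},\\ \tau_{m,g}(W)-\tau_0(W) &= \frac{(m_0(W)-m(W))(g_0-g)-(g_0-g)^2\,\tau_0(W)}{g_0(1-g_0)+(g_0-g)^2}. \end{aligned} \]

The bias is exactly zero when \(g=g_0\), whatever \(m\) is. Otherwise it involves only products of nuisance errors.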

Load required packages

We first load the required packages.

library(hal9001)
library(glmnet)
library(plotly) # for interactive plots

Data generating process

We simulate data from the following data generating process:

  • \(W=(W_1,W_2,W_3)\sim \text{Uniform}(-1,1)\);

  • \(A\sim \text{Bernoulli}(\text{expit}(-0.25W_1+W_2))\);

  • \(Y\sim N(\mu(A,W),1)\), where \(\mu(A,W)=-0.5+W_1+0.5W_2+0.3W_3+A[0.5W_1+\sin(2\pi W_2)]\). The true CATE is therefore \(\tau_0(W)=\mu(1,W)-\mu(0,W)=0.5W_1+\sin(2\pi W_2)\).
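Readers without `data/data_cate.csv` may want to simulate from this process directly. The sketch below does that. The helpers `simulate_cate_data()`, `true_cate()` and `expit()` are hypothetical and not part of the original tutorial. The sketch assumes \(W_1,W_2,W_3\) are drawn independently.

```r
# Hypothetical helpers (not from the original tutorial) that simulate from the
# data generating process above, assuming W1, W2, W3 are independent.
expit <- function(x) 1 / (1 + exp(-x))  # equivalent to stats::plogis()

# True CATE implied by mu(A, W): mu(1, W) - mu(0, W)
true_cate <- function(W1, W2) 0.5 * W1 + sin(2 * pi * W2)

simulate_cate_data <- function(n) {
  W1 <- runif(n, -1, 1)
  W2 <- runif(n, -1, 1)
  W3 <- runif(n, -1, 1)
  A <- rbinom(n, size = 1, prob = expit(-0.25 * W1 + W2))
  mu <- -0.5 + W1 + 0.5 * W2 + 0.3 * W3 + A * true_cate(W1, W2)
  Y <- rnorm(n, mean = mu, sd = 1)
  data.frame(W1 = W1, W2 = W2, W3 = W3, A = A, Y = Y)
}

set.seed(1234)
sim <- simulate_cate_data(500)
head(sim)
```

The simulated data have the same column names as the CSV file, so they can stand in for it. They will not reproduce that file's values.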

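set.seed(1234) # fixes the random cross-validation folds used by cv.glmnet() below
data <- read.csv("data/data_cate.csv")
W <- data[, grep("^W", colnames(data))] # covariate columns W1, W2, W3
A <- data$A
Y <- data$Y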

Step 1: Estimate nuisance functions using HAL

First, we use HAL to estimate the nuisance function \(m(W)=E(Y\mid W)\). In this tutorial, we manually enumerate the HAL basis functions, build the HAL design matrix, and then fit the lasso with cv.glmnet(). Of course, one could also call the fit_hal() function directly.
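With smoothness_orders = 0, the HAL basis consists of indicators \(1(W_j\ge c_j)\). With max_degree = 2, it also contains products of two such indicators. The knots \(c\) are placed at observed covariate values. The base-R helper `zero_order_hal_basis()` below is a hypothetical illustration of that construction, using every observation as a knot. hal9001's implementation may differ in details such as knot handling and sparse storage.

```r
# Hypothetical illustration of a zero-order HAL basis with up to two-way
# interactions; knots are placed at every observed covariate value.
zero_order_hal_basis <- function(x) {
  x <- as.matrix(x)
  d <- ncol(x)
  # ge[[j]][a, b] is TRUE when x[a, j] >= x[b, j]: column b is the indicator
  # 1(x_j >= c) with knot c = x[b, j]
  ge <- lapply(seq_len(d), function(j) outer(x[, j], x[, j], ">="))
  main <- do.call(cbind, ge)
  if (d < 2) return(1 * main)
  # two-way terms: 1(x_j >= c_j) * 1(x_k >= c_k), knots at observed pairs
  pairs <- combn(d, 2, simplify = FALSE)
  inter <- do.call(cbind, lapply(pairs, function(p) ge[[p[1]]] & ge[[p[2]]]))
  1 * cbind(main, inter)  # convert logical to numeric 0/1
}

set.seed(1)
x_small <- matrix(runif(5 * 2, -1, 1), ncol = 2)
phi_small <- zero_order_hal_basis(x_small)
dim(phi_small)  # 5 rows; 5 * 2 main-term columns + 5 interaction columns
```

Each row is one observation and each column is one basis function evaluated at all observations. This is the kind of matrix that the lasso then fits.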

# enumerate zero-order (indicator) basis functions with up to two-way interactions
basis_list <- enumerate_basis(x = as.matrix(W),
                              max_degree = 2L,
                              smoothness_orders = 0L)

# make HAL design matrix
phi_W <- make_design_matrix(X = as.matrix(W), blist = basis_list)

# fit HAL
m_fit <- cv.glmnet(x = phi_W, 
                   y = Y,
                   family = "gaussian", 
                   alpha = 1)
m <- as.numeric(predict(m_fit, newx = phi_W, s = "lambda.min", 
                        type = "response"))
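As noted above, fit_hal() wraps these steps in a single call. The minimal sketch below assumes a recent hal9001 release (0.4.x). In that release, fit_hal() accepts `X`, `Y`, `family`, `max_degree` and `smoothness_orders`, and predict() takes `new_data`. The objects `X_toy` and `Y_toy` are hypothetical simulated stand-ins for W and Y. fit_hal() applies its own default knot settings, so its fit need not match the manual one exactly.

```r
library(hal9001)

# Toy stand-ins for W and Y (hypothetical; simulated so the block runs alone)
set.seed(1)
n_toy <- 200
X_toy <- matrix(runif(n_toy * 3, -1, 1), ncol = 3,
                dimnames = list(NULL, c("W1", "W2", "W3")))
Y_toy <- X_toy[, 1] + 0.5 * X_toy[, 2] + rnorm(n_toy)

# fit_hal() enumerates the basis, builds the design matrix and runs the
# cross-validated lasso in one call; settings mirror the manual fit above
m_hal <- fit_hal(X = X_toy, Y = Y_toy, family = "gaussian",
                 max_degree = 2, smoothness_orders = 0)
m_hat <- as.numeric(predict(m_hal, new_data = X_toy))
```

The same call with `family = "binomial"` and the treatment as `Y` would target \(g(1\mid W)\).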

Next, we estimate the other nuisance function \(g(1\mid W)=P(A=1\mid W)\):

# fit HAL
g_fit <- cv.glmnet(x = phi_W, 
                   y = A,
                   family = "binomial", 
                   alpha = 1)
g <- as.numeric(predict(g_fit, newx = phi_W, s = "lambda.min", 
                        type = "response"))
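Here, m and g are predicted on the same observations used to fit them. Nie and Wager (2021) instead form the nuisance estimates by cross-fitting, so that each observation's predictions come from models that never saw it. The sketch below is a variant of the tutorial's procedure, not part of it. It uses a hypothetical helper `crossfit_predict()` built on cv.glmnet(). The helper accepts any design matrix, so phi_W could be passed as `x`. The toy data are simulated.

```r
library(glmnet)

# Hypothetical helper: out-of-fold (cross-fitted) predictions from cv.glmnet().
# Each observation is predicted by a model fitted without it.
crossfit_predict <- function(x, y, family = "gaussian", nfolds = 5) {
  n <- nrow(x)
  fold_id <- sample(rep(seq_len(nfolds), length.out = n))
  preds <- numeric(n)
  for (k in seq_len(nfolds)) {
    held_out <- fold_id == k
    fit <- cv.glmnet(x = x[!held_out, , drop = FALSE], y = y[!held_out],
                     family = family, alpha = 1)
    preds[held_out] <- as.numeric(predict(fit, newx = x[held_out, , drop = FALSE],
                                          s = "lambda.min", type = "response"))
  }
  preds
}

# Toy usage on simulated data
set.seed(1)
x_demo <- matrix(rnorm(200 * 5), ncol = 5)
y_demo <- x_demo[, 1] + rnorm(200)
a_demo <- rbinom(200, 1, plogis(x_demo[, 2]))
m_cf <- crossfit_predict(x_demo, y_demo, family = "gaussian")
g_cf <- crossfit_predict(x_demo, a_demo, family = "binomial")
```

With the tutorial's objects, the calls would be `crossfit_predict(phi_W, Y)` and `crossfit_predict(phi_W, A, family = "binomial")`.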

Step 2: Estimate the CATE under the R-loss

Recall that our loss function is \[ L_{m,g}(\tau)=\left( Y-m(W)-(A-g(1\mid W))\tau\right)^2. \] Because \(0<g(1\mid W)<1\), we have \(A-g(1\mid W)\neq 0\), so the loss can also be written as \[ L_{m,g}(\tau)=(A-g(1\mid W))^2\left(\frac{Y-m(W)}{A-g(1\mid W)}-\tau\right)^2. \] This suggests a simple way to estimate \(\tau_0\): a weighted regression of the pseudo-outcome \((Y-m(W))/(A-g(1\mid W))\) on basis functions of \(W\), with weights \((A-g(1\mid W))^2\). Minimizing this weighted squared error is the same as minimizing the empirical R-loss.

pseudo_outcome <- (Y-m)/(A-g)
pseudo_weights <- (A-g)^2
tau_fit <- cv.glmnet(x = phi_W,
                     y = pseudo_outcome,
                     weights = pseudo_weights,
                     family = "gaussian", 
                     nfolds = 5,
                     alpha = 1)
tau <- as.numeric(predict(tau_fit, newx = phi_W, s = "lambda.min", type = "response"))
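The toy check below illustrates why the weighted regression minimizes the R-loss. All values are simulated stand-ins, not the tutorial's data. First, the two forms of the loss agree observation by observation. Second, for a constant \(\tau\), three routes give the same minimizer:

- the weighted mean of the pseudo-outcome;
- a no-intercept regression of \(Y-m\) on \(A-g\);
- direct numerical minimization of the R-loss.

```r
# Toy values (simulated; not the tutorial's data)
set.seed(1)
n_toy <- 100
A_toy <- rbinom(n_toy, 1, 0.4)
g_toy <- runif(n_toy, 0.2, 0.8)  # stand-in for the propensity predictions g
m_toy <- rnorm(n_toy)            # stand-in for the outcome predictions m
Y_toy <- m_toy + 1.5 * (A_toy - g_toy) + rnorm(n_toy)

pseudo_toy <- (Y_toy - m_toy) / (A_toy - g_toy)
weights_toy <- (A_toy - g_toy)^2

# The two forms of the R-loss agree observation by observation (tau = 0.7)
lhs <- (Y_toy - m_toy - (A_toy - g_toy) * 0.7)^2
rhs <- weights_toy * (pseudo_toy - 0.7)^2
all.equal(lhs, rhs)

# Three routes to the minimizing constant tau
tau_wmean <- weighted.mean(pseudo_toy, weights_toy)                   # weighted regression
tau_ols <- unname(coef(lm(I(Y_toy - m_toy) ~ 0 + I(A_toy - g_toy))))  # residual on residual
tau_opt <- optimize(function(t) mean((Y_toy - m_toy - (A_toy - g_toy) * t)^2),
                    interval = c(-10, 10))$minimum                    # direct R-loss
c(tau_wmean, tau_ols, tau_opt)
```

The HAL fit above extends the same idea from a constant \(\tau\) to a lasso over basis functions of \(W\).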

Plot the estimated CATE function

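Now, let's plot a snapshot of the estimated CATE function as a function of \(W_1\) and \(W_2\), integrating out \(W_3\), and compare it with the true CATE function. To integrate out \(W_3\), we regress the fitted CATE values \(\hat\tau(W)\) (the vector tau) on a HAL basis of \((W_1,W_2)\). This estimates \(E(\hat\tau(W)\mid W_1,W_2)\), the average of the fitted CATE over \(W_3\) given \((W_1,W_2)\). The true CATE does not depend on \(W_3\) in this simulation, so the two surfaces are directly comparable.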

# CATE on W1, W2: project the fitted CATE values onto a HAL basis of (W1, W2)
basis_list_W1W2 <- enumerate_basis(x = as.matrix(data[, c("W1", "W2")]),
                                   max_degree = 2,
                                   smoothness_orders = 0)
phi_W1W2 <- make_design_matrix(X = as.matrix(data[, c("W1", "W2")]),
                               blist = basis_list_W1W2)
cate_W1W2_fit <- cv.glmnet(x = phi_W1W2, 
                           y = tau,
                           family = "gaussian", 
                           alpha = 1)

# define a grid of W1 and W2 values
W1_vals <- seq(-1, 1, 0.1)
W2_vals <- seq(-1, 1, 0.1)
grid <- expand.grid(W1 = W1_vals,
                    W2 = W2_vals)

# predict CATE on the grid and obtain true CATE on the grid
grid_hal_design <- make_design_matrix(X = as.matrix(grid),
                                      blist = basis_list_W1W2)
grid$pred_cate <- as.numeric(predict(cate_W1W2_fit, newx = grid_hal_design, 
                                     s = "lambda.min", type = "response"))
# expand.grid() varies W1 fastest, so filling by row gives rows = W2 and
# columns = W1, the orientation plot_ly() uses for z (rows = y, columns = x)
pred_cate_mat <- matrix(grid$pred_cate, 
                        nrow = length(W2_vals), 
                        ncol = length(W1_vals),
                        byrow = TRUE)
# true CATE from the data generating process: mu(1, W) - mu(0, W)
grid$true_cate <- 0.5 * grid$W1 + sin(2 * pi * grid$W2)
true_cate_mat <- matrix(grid$true_cate, 
                        nrow = length(W2_vals), 
                        ncol = length(W1_vals),
                        byrow = TRUE)
# plot the estimated CATE surface
plt_pred <- plot_ly(x = W1_vals, y = W2_vals, z = pred_cate_mat, 
                    colorscale = "Greens") %>%
  add_surface() %>% 
  layout(title = list(text = "Estimated CATE", 
                      font = list(size = 18, color = "black"),
                      yref = "container", 
                      y = 0.98),
         scene = list(xaxis = list(title = "W1", titlefont = list(size = 14)),
                      yaxis = list(title = "W2", titlefont = list(size = 14)),
                      zaxis = list(title = "CATE", titlefont = list(size = 14), 
                                   range = c(-1.5, 1.5)),
                      camera = list(eye = list(x = -1.5, y = 1.5, z = 0.8))),
         legend = list(x = 0.8, y = 1, font = list(size = 12)))
plt_pred
# plot the true CATE surface
plt_true <- plot_ly(x = W1_vals, y = W2_vals, z = true_cate_mat, 
                    colorscale = "Reds") %>%
  add_surface() %>% 
  layout(title = list(text = "True CATE", 
                      font = list(size = 18, color = "black"),
                      yref = "container", 
                      y = 0.98),
         scene = list(xaxis = list(title = "W1", titlefont = list(size = 14)),
                      yaxis = list(title = "W2", titlefont = list(size = 14)),
                      zaxis = list(title = "CATE", titlefont = list(size = 14), 
                                   range = c(-1.5, 1.5)),
                      camera = list(eye = list(x = -1.5, y = 1.5, z = 0.8))),
         legend = list(x = 0.8, y = 1, font = list(size = 12)))
plt_true
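The reshape from the long grid to a matrix is easy to get wrong. expand.grid() varies its first argument fastest, and a plotly surface trace is expected to read the rows of z as y values and the columns as x values. The base-R check below uses a toy grid whose two axes have different lengths, so a transposed reshape would fail loudly instead of silently swapping the axes.

```r
# Toy grid: different lengths on the two axes
W1_toy <- seq(-1, 1, by = 0.5)    # 5 values (x axis)
W2_toy <- seq(-1, 1, by = 0.25)   # 9 values (y axis)
grid_toy <- expand.grid(W1 = W1_toy, W2 = W2_toy)  # W1 varies fastest

# Known surface: the true CATE of the data generating process
grid_toy$cate <- 0.5 * grid_toy$W1 + sin(2 * pi * grid_toy$W2)

# Filling by row gives one row per W2 value and one column per W1 value
z_toy <- matrix(grid_toy$cate,
                nrow = length(W2_toy),
                ncol = length(W1_toy),
                byrow = TRUE)

# Entry [i, j] should equal the surface at (W1_toy[j], W2_toy[i])
expected <- outer(W2_toy, W1_toy, function(w2, w1) 0.5 * w1 + sin(2 * pi * w2))
all.equal(z_toy, expected)  # TRUE
```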

References

Nie, Xinkun, and Stefan Wager. “Quasi-oracle estimation of heterogeneous treatment effects.” Biometrika 108.2 (2021): 299-319.