Estimating Conditional Average Treatment Effect using HAL
Introduction
In this tutorial, we demonstrate how to estimate the conditional average treatment effect (CATE) function using HAL.
Problem setup
We consider the following setup. The observed data consist of \(n\) i.i.d. copies of \[ O=(W,A,Y)\sim P_0\in\mathcal{M}, \] where
\(W \in \mathbb{R}^d\) are baseline covariates,
\(A \in \{0,1\}\) is a binary treatment indicator, and
\(Y \in \mathbb{R}\) is an outcome of interest.
The statistical model \(\mathcal{M}\) is nonparametric.
Our parameter of interest is the CATE function: \[ \tau_0(W)=E_0(Y\mid A=1,W)-E_0(Y\mid A=0,W), \] which represents the expected effect of the treatment for an individual with covariates \(W\).
In this example, we learn the CATE using the R-loss \(L_{m,g}(\tau)\) of Nie and Wager (2021), indexed by two nuisance parameters \(m\) and \(g\): \[ L_{m,g}(\tau)=\left( Y-m(W)-(A-g(1\mid W))\tau\right)^2, \] where \[ m(W)=E(Y\mid W)\quad\text{and}\quad g(1\mid W)=P(A=1\mid W). \] This loss identifies the CATE in the sense that \(\tau_0=\arg\min_{\tau}P_0 L_{m,g}(\tau)\) whenever \(g=g_0\), regardless of the choice of \(m\); moreover, the loss is Neyman orthogonal, so the resulting estimator is only second-order sensitive to estimation error in \(m\) and \(g\).
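To see why \(g=g_0\) suffices, note that the population risk can be minimized pointwise in \(W\); the first-order condition gives (a short sketch of the standard argument): \[ \tau^*(W)=\frac{E\big[(Y-m(W))(A-g_0(1\mid W))\mid W\big]}{E\big[(A-g_0(1\mid W))^2\mid W\big]} =\frac{\text{Cov}(Y,A\mid W)}{\text{Var}(A\mid W)} =\frac{g_0(1\mid W)\{1-g_0(1\mid W)\}\,\tau_0(W)}{g_0(1\mid W)\{1-g_0(1\mid W)\}}=\tau_0(W), \] where the \(m(W)\) term drops out of the numerator because \(E[A-g_0(1\mid W)\mid W]=0\).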
Load required packages
We first load the required packages.
library(hal9001)
library(glmnet)
library(plotly) # for interactive plots
Data generating process
We simulate data from the following data generating process (a simulation sketch is given after the list):
\(W=(W_1,W_2,W_3)\sim \text{Uniform}(-1,1)\);
\(A\sim \text{Bernoulli}(\text{expit}(-0.25W_1+W_2))\);
\(Y\sim N(\mu(A,W),1)\), where \(\mu(A,W)=-0.5+W_1+0.5W_2+0.3W_3+A[0.5W_1+\sin(2\pi W_2)]\).
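For reference, data from this process could be drawn as follows (a minimal sketch; simulate_cate_data is a hypothetical helper and is not used in the rest of the tutorial, which instead loads a pre-simulated dataset):
# hypothetical helper: draw n observations from the data generating process above
simulate_cate_data <- function(n) {
  W1 <- runif(n, -1, 1)
  W2 <- runif(n, -1, 1)
  W3 <- runif(n, -1, 1)
  A <- rbinom(n, size = 1, prob = plogis(-0.25 * W1 + W2))
  mu <- -0.5 + W1 + 0.5 * W2 + 0.3 * W3 + A * (0.5 * W1 + sin(2 * pi * W2))
  Y <- rnorm(n, mean = mu, sd = 1)
  data.frame(W1 = W1, W2 = W2, W3 = W3, A = A, Y = Y)
}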
set.seed(1234)
# load the pre-simulated dataset
data <- read.csv("data/data_cate.csv")
W <- data[, grep("W", colnames(data))]
A <- data$A
Y <- data$Y
Step 1: Estimate nuisance functions using HAL
First, we use HAL to estimate the nuisance function \(m(W)=E(Y\mid W)\). In this tutorial, we manually enumerate the list of HAL basis functions, compute the HAL design matrix, and then fit the lasso. Of course, one could also call the fit_hal() function directly, as sketched below; the manual approach follows after.
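For reference, a direct call might look roughly like the sketch below (arguments mirror the manual approach; treat it as illustrative rather than a drop-in replacement):
# sketch: fit m(W) = E(Y | W) with fit_hal() directly
m_hal_fit <- fit_hal(X = as.matrix(W), Y = Y,
                     family = "gaussian",
                     max_degree = 2,
                     smoothness_orders = 0)
m_hal_pred <- predict(m_hal_fit, new_data = as.matrix(W))
The manual approach proceeds as follows.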
# enumerate basis list
basis_list <- enumerate_basis(x = as.matrix(W),
max_degree = 2L,
smoothness_orders = 0L)
# make HAL design matrix
phi_W <- make_design_matrix(X = as.matrix(W), blist = basis_list)
# fit HAL
m_fit <- cv.glmnet(x = phi_W,
y = Y,
family = "gaussian",
alpha = 1)
m <- as.numeric(predict(m_fit, newx = phi_W, s = "lambda.min",
type = "response"))Next, we estimate the other nuisance function \(g(1\mid W)=P(A=1\mid W)\):
# fit HAL
g_fit <- cv.glmnet(x = phi_W,
y = A,
family = "binomial",
alpha = 1)
g <- as.numeric(predict(g_fit, newx = phi_W, s = "lambda.min",
type = "response"))Step 2: Estimate the CATE under the R-loss
Recall that our loss function is: \[ L_{m,g}(\tau)=\left( Y-m(W)-(A-g(1\mid W))\tau\right)^2, \] which can also be expressed as: \[ L_{m,g}(\tau)=(A-g(1\mid W))^2\left(\frac{Y-m(W)}{A-g(1\mid W)}-\tau\right)^2. \] This suggests a simple way to estimate \(\tau_0\) via a weighted regression of the pseudo-outcome \((Y-m(W))/(A-g(1\mid W))\) on basis functions of \(W\), with weights given by \((A-g(1\mid W))^2\).
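Because the pseudo-outcome divides by \(A-g(1\mid W)\), it can be numerically safer to bound the estimated propensity score away from 0 and 1 before forming it (an optional safeguard; the bounds below are illustrative):
# optional: truncate the estimated propensity score (illustrative bounds)
g <- pmin(pmax(g, 0.025), 0.975)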
pseudo_outcome <- (Y-m)/(A-g)
pseudo_weights <- (A-g)^2
tau_fit <- cv.glmnet(x = phi_W,
y = pseudo_outcome,
weights = pseudo_weights,
family = "gaussian",
nfolds = 5,
alpha = 1)
tau <- as.numeric(predict(tau_fit, newx = phi_W, s = "lambda.min", type = "response"))
Plot the estimated CATE function
Now, let's plot a snapshot of the estimated CATE function as a function of \(W_1\) and \(W_2\), integrating out \(W_3\), and compare it with the true CATE function. To do so, we regress the estimated CATE values on HAL basis functions of \(W_1\) and \(W_2\) only, which approximates the estimated CATE averaged over \(W_3\).
# CATE on W1, W2
basis_list_W1W2 <- enumerate_basis(x = as.matrix(data[, c("W1", "W2")]),
                                   max_degree = 2L,
                                   smoothness_orders = 0L)
phi_W1W2 <- make_design_matrix(X = as.matrix(data[, c("W1", "W2")]),
blist = basis_list_W1W2)
cate_W1W2_fit <- cv.glmnet(x = phi_W1W2,
y = tau,
family = "gaussian",
alpha = 1)
# define a grid of W1 and W2 values
W1_vals <- seq(-1, 1, 0.1)
W2_vals <- seq(-1, 1, 0.1)
grid <- expand.grid(W1 = W1_vals,
W2 = W2_vals)
# predict CATE on the grid and obtain true CATE on the grid
grid_hal_design <- make_design_matrix(X = as.matrix(grid),
blist = basis_list_W1W2)
grid$pred_cate <- as.numeric(predict(cate_W1W2_fit, newx = grid_hal_design,
s = "lambda.min", type = "response"))
pred_cate_mat <- matrix(grid$pred_cate,
nrow = length(W1_vals),
ncol = length(W2_vals),
byrow = TRUE)
# true CATE implied by the data generating process: 0.5*W1 + sin(2*pi*W2)
grid$true_cate <- 0.5 * grid$W1 + sin(2 * pi * grid$W2)
true_cate_mat <- matrix(grid$true_cate,
nrow = length(W1_vals),
ncol = length(W2_vals),
                        byrow = TRUE)
# plot the estimated CATE surface
plt_pred <- plot_ly(x = W1_vals, y = W2_vals, z = pred_cate_mat,
colorscale = "Greens") %>%
add_surface() %>%
layout(title = list(text = "Estimated CATE",
font = list(size = 18, color = "black"),
yref = "container",
y = 0.98),
scene = list(xaxis = list(title = "W1", titlefont = list(size = 14)),
yaxis = list(title = "W2", titlefont = list(size = 14)),
zaxis = list(title = "CATE", titlefont = list(size = 14),
range = c(-1.5, 1.5)),
camera = list(eye = list(x = -1.5, y = 1.5, z = 0.8))),
legend = list(x = 0.8, y = 1, font = list(size = 12)))
plt_pred
# plot the true CATE surface
plt_true <- plot_ly(x = W1_vals, y = W2_vals, z = true_cate_mat,
colorscale = "Reds") %>%
add_surface() %>%
layout(title = list(text = "True CATE",
font = list(size = 18, color = "black"),
yref = "container",
y = 0.98),
scene = list(xaxis = list(title = "W1", titlefont = list(size = 14)),
yaxis = list(title = "W2", titlefont = list(size = 14)),
zaxis = list(title = "CATE", titlefont = list(size = 14),
range = c(-1.5, 1.5)),
camera = list(eye = list(x = -1.5, y = 1.5, z = 0.8))),
legend = list(x = 0.8, y = 1, font = list(size = 12)))
plt_true
References
Nie, Xinkun, and Stefan Wager. "Quasi-oracle estimation of heterogeneous treatment effects." Biometrika 108.2 (2021): 299-319.