Building a Diverse Super Learner Library with Candidate HAL Learners

Author

Sky Qiu

Introduction

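This tutorial shows how to build a discrete Super Learner (SL) library composed of multiple highly adaptive lasso (HAL) candidates, each with a different specification, and then plug the resulting learners into a targeted maximum likelihood estimation (TMLE) pipeline to estimate the average treatment effect (ATE). The aim is to make the HAL library diverse enough that cross-validation can select the best-performing specification separately for each nuisance function we need to estimate: the treatment mechanism and the outcome regression.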

Load required packages

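The sl3 R package implements the Super Learner and provides a convenient interface for defining candidate learners, including HAL (via Lrnr_hal9001, which wraps the hal9001 package). We also use origami to create cross-validation folds and tmle for the final targeting step.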

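library(hal9001) # highly adaptive lasso (HAL)
library(sl3) # super learner
library(origami) # for V-fold cross-validation
library(tmle) # targeted maximum likelihood estimation
options(sl3.verbose = TRUE) # print progress messages while training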

Data

Let’s load the example data set:

set.seed(123)
data <- read.csv("data/data_dsl.csv")
W1 <- data$W1
W2 <- data$W2
W3 <- data$W3
A <- data$A
Y <- data$Y
folds <- make_folds(n = nrow(data), V = 5) # 5-fold CV folds, shared by all tasks below
head(data)
           W1 W2         W3 A         Y
1 -0.56047565  0  1.5384302 0 0.3022093
2 -0.23017749  1 -0.1097103 1 1.3960467
3  1.55870831  0  0.5114708 1 1.2924846
4  0.07050839  1  0.2139580 1 2.9539629
5  0.12928774  1 -0.1861207 1 3.0885005
6  1.71506499  0 -0.1203938 0 2.8793310

HAL super learner for the treatment mechanism

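First, let’s build a super learner library for the treatment mechanism \(g(1\mid W)=P(A=1\mid W)\). Each HAL candidate is specified by a formula passed to Lrnr_hal9001. Within each h() term, s sets the smoothness order of the basis functions (s=0 gives piecewise-constant fits, s=1 continuous piecewise-linear ones), k sets the number of knot points, and listing two variables in one h() term adds their interaction. Here formula_1 contains main terms only, formula_2 adds a W1-by-W2 interaction, and formula_3 uses the same terms as formula_2 with first-order (s=1) basis functions. None of the three formulas contains a term for W3.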
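To make s and k concrete, the base-R sketch below builds the kind of basis functions HAL uses for a single covariate, at a few hand-picked knots. It is an illustration only: hal9001 chooses its own knot points from the data (k controls how many) and then fits a lasso over the resulting basis. The helper hal_basis_1d() and the toy inputs are hypothetical and not part of hal9001.

```r
# Hypothetical helper illustrating HAL-style basis functions (not hal9001's internals).
# s = 0: indicator basis I(x >= u)        -> piecewise-constant fits
# s = 1: hinge basis (x - u) * I(x >= u)  -> continuous piecewise-linear fits
hal_basis_1d <- function(x, knots, s = 0) {
  sapply(knots, function(u) if (s == 0) as.numeric(x >= u) else pmax(x - u, 0))
}

x     <- c(-1, 0, 0.5, 2)
knots <- c(0, 1)  # hand-picked knot points for illustration
B0 <- hal_basis_1d(x, knots, s = 0)
B1 <- hal_basis_1d(x, knots, s = 1)
print(B0)  # columns: I(x >= 0), I(x >= 1)
print(B1)  # columns: (x - 0)_+, (x - 1)_+

# A zero-order two-way interaction term such as h(W1, W2, s = 0) uses products
# of univariate indicators, e.g. I(W1 >= 0) * I(W2 >= 1)
w1 <- c(-1, 0.5, 2)
w2 <- c(0, 1, 1)
inter <- as.numeric(w1 >= 0) * as.numeric(w2 >= 1)
print(inter)  # 0 1 1
```

With s = 0 a fitted function can only jump at the selected knots; with s = 1 it can only change slope there, which is why formula_3 yields smoother candidate fits than formula_2.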

# specify formulas
formula_1 <- ~ h(W1,s=0,k=20) + h(W2,s=0)
formula_2 <- ~ h(W1,s=0,k=20) + h(W2,s=0) + h(W1,W2,s=0,k=10)
formula_3 <- ~ h(W1,s=1,k=20) + h(W2,s=1) + h(W1,W2,s=1,k=10)

# define candidate HAL learners
g_hal_1 <- Lrnr_hal9001$new(formula = formula_1)
g_hal_2 <- Lrnr_hal9001$new(formula = formula_2)
g_hal_3 <- Lrnr_hal9001$new(formula = formula_3)

# define discrete SL library
g_lib <- Stack$new(g_hal_1, g_hal_2, g_hal_3)
g_sl <- Lrnr_sl$new(learners = g_lib, 
                    metalearner = Lrnr_cv_selector$new(loss_loglik_binomial))
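Because the metalearner is Lrnr_cv_selector, this super learner is a discrete SL: it computes each candidate's cross-validated risk under the supplied loss and keeps the single candidate with the lowest risk. The base-R sketch below mimics that selection on simulated data, with logistic regressions standing in for the HAL candidates. The data-generating process, the candidate formulas, and the helper neg_loglik() are hypothetical, and sl3's own implementation details (for example, any bounding of predictions inside the loss) may differ.

```r
set.seed(2024)
n  <- 1000
W1 <- rnorm(n)
W2 <- rbinom(n, 1, 0.5)
A  <- rbinom(n, 1, plogis(-0.2 + 0.8 * W1 + 0.5 * W1 * W2))  # hypothetical DGP
dat <- data.frame(W1, W2, A)

# Candidate learners: logistic regressions standing in for the HAL formulas
candidates <- list(
  null  = A ~ 1,
  main  = A ~ W1 + W2,
  inter = A ~ W1 * W2
)

# Negative binomial log-likelihood (the kind of loss the metalearner minimizes)
neg_loglik <- function(p, y) -(y * log(p) + (1 - y) * log(1 - p))

# V-fold cross-validation: each observation is predicted by a fit that did not see it
V    <- 5
fold <- sample(rep(seq_len(V), length.out = n))
cv_risk <- sapply(candidates, function(f) {
  loss <- numeric(n)
  for (v in seq_len(V)) {
    fit <- glm(f, family = binomial(), data = dat[fold != v, ])
    p   <- predict(fit, newdata = dat[fold == v, ], type = "response")
    loss[fold == v] <- neg_loglik(p, dat$A[fold == v])
  }
  mean(loss)
})

print(round(cv_risk, 4))
selected <- names(which.min(cv_risk))  # discrete SL: keep the lowest-risk candidate
selected
```

Which candidate wins depends on the data; the point is that the choice is driven by cross-validated risk rather than in-sample fit, which tends to favor the most flexible candidate.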

Next, we define a task for fitting the treatment mechanism.

# define task for fitting the treatment mechanism
task_g <- sl3_Task$new(data = data, 
                       covariates = c("W1", "W2", "W3"),
                       outcome = "A",
                       outcome_type = "binomial",
                       folds = folds)

# fit the treatment mechanism and obtain predictions
fit_g <- g_sl$train(task_g)
g1W <- fit_g$predict(task_g) # P(A=1|W)

HAL super learner for the outcome regression

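Now, let’s build a super learner library for the outcome regression \(Q(A,W)=E(Y\mid A,W)\). The candidates again differ in flexibility: formula_Q_1 uses the dot (.) shorthand to include zero-order terms for every covariate together with all pairwise interactions; formula_Q_2 lists the main terms explicitly and adds only the A-by-W1 and A-by-W3 interactions; formula_Q_3 keeps the terms of formula_Q_2 but uses first-order (s=1) basis functions everywhere except the binary treatment main term h(A, s=0).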

# specify formulas
formula_Q_1 <- ~ h(., s=0,k=20) + h(.,., s=0,k=20)
formula_Q_2 <- ~ h(W1,s=0,k=20) + h(W2,s=0) + h(W3,s=0,k=20) + h(A, s=0) +
  h(A,W1,s=0,k=10) + h(A,W3,s=0,k=10)
formula_Q_3 <- ~ h(W1,s=1,k=20) + h(W2,s=1) + h(W3,s=1,k=20) + h(A, s=0) +
  h(A,W1,s=1,k=10) + h(A,W3,s=1,k=10)

# define candidate HAL learners
Q_hal_1 <- Lrnr_hal9001$new(formula = formula_Q_1)
Q_hal_2 <- Lrnr_hal9001$new(formula = formula_Q_2)
Q_hal_3 <- Lrnr_hal9001$new(formula = formula_Q_3)

# define discrete SL library
Q_lib <- Stack$new(Q_hal_1, Q_hal_2, Q_hal_3)
Q_sl <- Lrnr_sl$new(learners = Q_lib, 
                    metalearner = Lrnr_cv_selector$new(loss_squared_error))
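Assuming the dot (.) shorthand in hal9001 formulas expands to every covariate in the task (here A, W1, W2 and W3), formula_Q_1 contains all 4 main terms and all 6 two-way interactions, whereas formula_Q_2 and formula_Q_3 keep only two of those interactions. The short base-R check below only manipulates term labels (it does not call hal9001) and lists the interactions that appear in formula_Q_1 alone.

```r
covs <- c("A", "W1", "W2", "W3")  # covariates of task_Q

# All two-way interactions, as h(., .) is assumed to generate them
pair_terms <- apply(combn(covs, 2), 2, paste, collapse = ":")
print(pair_terms)

# Interactions written out explicitly in formula_Q_2 and formula_Q_3
hand_picked <- c("A:W1", "A:W3")

only_in_Q1 <- setdiff(pair_terms, hand_picked)
print(only_in_Q1)  # "A:W2" "W1:W2" "W1:W3" "W2:W3"
```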

Similarly, we define tasks for fitting the outcome regression. This time we need three tasks: one for fitting \(E(Y\mid A,W)\) on the observed data, and two counterfactual tasks in which A is set to 1 and to 0, respectively, for every observation, so that we can obtain predictions under each treatment level.

# define task for fitting the outcome regression (and counterfactual tasks)
task_Q <- sl3_Task$new(data = data,
                       covariates = c("A", "W1", "W2", "W3"),
                       outcome = "Y",
                       outcome_type = "continuous",
                       folds = folds)
data_A1 <- data
data_A1$A <- 1
data_A0 <- data
data_A0$A <- 0
task_Q1 <- sl3_Task$new(data = data_A1,
                        covariates = c("A", "W1", "W2", "W3"),
                        outcome = "Y",
                        outcome_type = "continuous",
                        folds = folds)
task_Q0 <- sl3_Task$new(data = data_A0,
                        covariates = c("A", "W1", "W2", "W3"),
                        outcome = "Y",
                        outcome_type = "continuous",
                        folds = folds)

# fit the outcome regression and obtain predictions
fit_Q <- Q_sl$train(task_Q) 
QAW <- fit_Q$predict(task_Q)  # E(Y|A,W)
Q1W <- fit_Q$predict(task_Q1) # E(Y|A=1,W)
Q0W <- fit_Q$predict(task_Q0) # E(Y|A=0,W)
mean(Q1W - Q0W) # plug-in (G-computation) estimate of the ATE, before targeting
[1] 0.6876823
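The quantity printed above, mean(Q1W - Q0W), is the plug-in (G-computation) estimate of the ATE: the average, over the observed W, of the predicted outcome under A=1 minus the predicted outcome under A=0. The sketch below repeats the same counterfactual-prediction recipe in base R on simulated data, with a linear model standing in for the HAL super learner; the DGP (true ATE 0.85) and all variable names are hypothetical.

```r
set.seed(7)
n  <- 2000
W1 <- rnorm(n)
W2 <- rbinom(n, 1, 0.5)
A  <- rbinom(n, 1, plogis(0.4 * W1))
Y  <- 1 + 0.7 * A + 0.5 * W1 + 0.3 * A * W2 + rnorm(n)  # true ATE = 0.7 + 0.3 * 0.5 = 0.85
dat <- data.frame(W1, W2, A, Y)

# Outcome regression E(Y | A, W) (linear model as a stand-in for the HAL super learner)
fit_Q <- lm(Y ~ A * (W1 + W2), data = dat)

# Counterfactual copies of the data with A set to 1 and to 0 for everyone
dat_A1 <- transform(dat, A = 1)
dat_A0 <- transform(dat, A = 0)
Q1W <- predict(fit_Q, newdata = dat_A1)
Q0W <- predict(fit_Q, newdata = dat_A0)

ate_plugin <- mean(Q1W - Q0W)
ate_plugin
```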

TMLE for the ATE

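Finally, let’s supply the estimated treatment mechanism g1W and the initial outcome-regression estimates Q0W and Q1W to the tmle() function from the tmle R package, which carries out the targeting step and reports the TMLE estimates with influence-curve-based standard errors. Note that tmle() expects Q as an n-by-2 matrix whose first column is \(E(Y\mid A=0,W)\) and whose second column is \(E(Y\mid A=1,W)\).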

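tmle_obj <- tmle(W = data[, c("W1", "W2", "W3")],
                 Y = data$Y,
                 A = data$A,
                 Q = as.matrix(data.frame(Q0W = Q0W, Q1W = Q1W)), # columns: E(Y|A=0,W), E(Y|A=1,W)
                 g1W = g1W,            # user-supplied estimate of P(A=1|W)
                 family = "gaussian",  # continuous outcome
                 prescreenW.g = FALSE, # do not prescreen W for the g model
                 evalATT = FALSE)      # do not evaluate the ATT/ATC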
summary(tmle_obj)
 Initial estimation of Q
     Procedure: user-supplied values
     Cross-validated R squared :  0.3904 

 Estimation of g (treatment mechanism)
     Procedure: user-supplied values, ensemble

 Estimation of g.Z (intermediate variable assignment mechanism)
     Procedure: No intermediate variable 

 Estimation of g.Delta (missingness mechanism)
     Procedure: No missingness, ensemble

 Bounds on g: (0.036, 1) 

 Bounds on g for ATT/ATC: (0.036, 0.964) 

 Marginal Mean under Treatment (EY1)
   Parameter Estimate:  1.7269
   Estimated Variance:  0.0045802
              p-value:  <2e-16
    95% Conf Interval:  (1.5942, 1.8595)

 Marginal Mean under Comparator (EY0)
   Parameter Estimate:  1.029
   Estimated Variance:  0.0063751
              p-value:  <2e-16
    95% Conf Interval:  (0.87253, 1.1855)

 Additive Effect
   Parameter Estimate:  0.69787
   Estimated Variance:  0.0088281
              p-value:  1.1069e-13
    95% Conf Interval:  (0.51372, 0.88202)
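The confidence intervals in this output are Wald-type intervals built from each estimate and its estimated variance. As a quick check, the snippet below plugs the printed Additive Effect values into estimate ± 1.96 × sqrt(variance) and a two-sided normal p-value; small discrepancies can arise from rounding in the printed output.

```r
# Values copied from the "Additive Effect" block of summary(tmle_obj)
est     <- 0.69787
var_est <- 0.0088281

se <- sqrt(var_est)
ci <- est + c(-1, 1) * qnorm(0.975) * se  # Wald-type 95% CI
p  <- 2 * pnorm(-abs(est / se))           # two-sided p-value

round(ci, 5)  # 0.51372 0.88202
signif(p, 3)
```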

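The true ATE for this data-generating process (DGP) is 0.72. It lies inside the TMLE 95% confidence interval (0.514, 0.882), and the targeted estimate (0.698) is closer to the truth than the initial plug-in estimate (0.688).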
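For intuition about what happens once initial estimates like g1W, Q0W and Q1W are handed to tmle(), here is a bare-bones TMLE for the ATE in base R on simulated data. It follows the standard recipe: scale Y to [0, 1], fluctuate the initial outcome fit along the clever covariate H(A,W) = A/g(1|W) - (1-A)/(1-g(1|W)) with a logistic submodel, then compute the updated plug-in estimate and an influence-curve-based confidence interval. This is a sketch under stated assumptions, not the tmle package's implementation, which differs in details such as how g is bounded and how the fluctuation is parameterized; the DGP (true ATE 0.7), the glm/lm stand-ins for the HAL super learners, and the bounding constants are all hypothetical.

```r
set.seed(11)
n  <- 2000
W1 <- rnorm(n)
W2 <- rbinom(n, 1, 0.5)
A  <- rbinom(n, 1, plogis(0.3 * W1 - 0.2 * W2))
Y  <- 1 + 0.7 * A + 0.5 * W1 + 0.4 * W2 + rnorm(n)  # hypothetical DGP, true ATE = 0.7

# Initial nuisance estimates (glm/lm stand-ins for the HAL super learners)
g1W <- predict(glm(A ~ W1 + W2, family = binomial()), type = "response")
g1W <- pmin(pmax(g1W, 0.025), 0.975)  # bound g away from 0 and 1
fit_Q <- lm(Y ~ A + W1 + W2)
Q1W <- predict(fit_Q, newdata = data.frame(A = 1, W1 = W1, W2 = W2))
Q0W <- predict(fit_Q, newdata = data.frame(A = 0, W1 = W1, W2 = W2))
QAW <- ifelse(A == 1, Q1W, Q0W)

# Map Y to [0, 1] and bound the initial fits so qlogis() stays finite
a  <- min(Y)
b  <- max(Y)
Ys <- (Y - a) / (b - a)
to01 <- function(q) pmin(pmax((q - a) / (b - a), 0.005), 0.995)

# Targeting step: logistic fluctuation along the clever covariate H(A, W)
H   <- A / g1W - (1 - A) / (1 - g1W)
eps <- unname(coef(glm(Ys ~ -1 + H, offset = qlogis(to01(QAW)),
                       family = quasibinomial())))

# Updated (targeted) counterfactual predictions on the [0, 1] scale
Q1W_star <- plogis(qlogis(to01(Q1W)) + eps / g1W)
Q0W_star <- plogis(qlogis(to01(Q0W)) - eps / (1 - g1W))
QAW_star <- ifelse(A == 1, Q1W_star, Q0W_star)

# TMLE of the ATE, mapped back to the original scale of Y
psi_01 <- mean(Q1W_star - Q0W_star)
psi    <- (b - a) * psi_01

# Influence-curve-based standard error and Wald-type 95% CI
IC <- (b - a) * (H * (Ys - QAW_star) + Q1W_star - Q0W_star - psi_01)
se <- sqrt(var(IC) / n)
ci <- psi + c(-1, 1) * qnorm(0.975) * se
c(estimate = psi, lower = ci[1], upper = ci[2])
```

In this sketch the fluctuation makes mean(H * (Ys - QAW_star)) numerically zero, so the estimated efficient influence curve averages to zero; that is the property TMLE relies on for its influence-curve-based inference.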