Estimating the Effect of Counterfactual Treatment Regimes in Longitudinal Data with LTMLE

Authors

Sky Qiu and Toru Shirakawa

Introduction

In this coding example, we demonstrate how to estimate the effect of a user-specified counterfactual treatment regime (or any causal contrast between two regimes) in a longitudinal data setting using the ltmle package in R.

Loading required packages

The ltmle R package implements longitudinal TMLE based on sequential regression. It supports the evaluation of static and dynamic treatment regimes. The package is available on CRAN.

library(ltmle)
library(purrr)

Data and statistical model

We will use a synthetic dataset included in the dgps folder that mimics a longitudinal study with \(K=4\) time points. The dataset contains 5000 patients.

Let’s first inspect the example data:

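# read the example data
data <- read.csv("data/data_ltmle.csv")
# ltmle expects censoring nodes as factors with levels "censored" and "uncensored";
# BinaryToCensoring() converts the binary "is uncensored" indicators accordingly
C_cols <- paste0("C", 1:4)
walk(C_cols, function(.x) {
  # `<<-` updates `data` in the enclosing (global) environment
  data[[.x]] <<- BinaryToCensoring(is.uncensored = data[[.x]])
})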
head(data)
           W1 W2         L1 A1         C1 Y1         L2 A2         C2 Y2
1 -0.56047565  0 -1.5571097  0 uncensored  0 -2.2444741  0 uncensored  0
2 -0.23017749  0 -0.2616424  1 uncensored  0  0.2115774  0 uncensored  0
3  1.55870831  0 -0.1587545  1 uncensored  0  2.2445080  1 uncensored  0
4  0.07050839  1  0.6537786  0 uncensored  0  2.6829173  0   censored NA
5  0.12928774  1  0.8111676  0 uncensored  1 -0.4408271  1 uncensored  1
6  1.71506499  1  2.4020889  1 uncensored  0  3.1421296  1 uncensored  1
         L3 A3         C3 Y3         L4 A4         C4 Y4
1 -1.471409  0 uncensored  0 -3.3924447  1 uncensored  0
2  0.912274  0 uncensored  0 -0.5059628  1 uncensored  0
3  1.780963  1 uncensored  1  2.8658147  1   censored NA
4        NA NA       <NA> NA         NA NA       <NA> NA
5  2.263297  1 uncensored  1  1.9901538  1   censored NA
6  3.816034  1   censored NA         NA NA       <NA> NA

The example data cover four time points. At each time point \(t\), we observe a continuous time-varying covariate \(L_t\), a binary treatment \(A_t\), a binary censoring indicator \(C_t\), and a binary outcome \(Y_t\). In addition, there are two baseline covariates, \(W_1\) and \(W_2\). The observed data structure is: \[ O=(W=(W_1,W_2),L_1,A_1,C_1,Y_1,\dots,L_4,A_4,C_4,Y_4)\sim P_0\in\mathcal{M}. \] The assumed temporal ordering of the nodes within each time point is: \[ L_t\rightarrow A_t\rightarrow C_t\rightarrow Y_t. \]

This data structure and time ordering could arise, for instance, in a longitudinal diabetes management program where patients are followed quarterly over one year (four visits in total). At each visit, the patient’s blood pressure is first measured \((L_t)\), after which the clinician decides whether to intensify diabetes treatment \((A_t)\). Following the visit, the patient is asked to complete a laboratory test for their HbA1c level, whose result determines whether diabetes control is achieved \((Y_t)\). The two baseline covariates are the patient’s age \((W_1)\) and an indicator of whether they have a comorbid condition such as cardiovascular disease \((W_2)\).

We assume that the physician’s treatment decision at each visit depends only on the baseline covariates and on the patient’s blood pressure readings from the current and the immediately preceding visit. The censoring variable \((C_t)\) indicates whether the patient was lost to follow-up after visit \(t\); once a patient is censored, all later measurements are missing.

# define nodes
K <- 4
Anodes <- paste0("A", 1:K)
Cnodes <- paste0("C", 1:K)
Lnodes <- paste0("L", 1:K)
Ynodes <- paste0("Y", 1:K)
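
The node ordering and the censoring structure described above can be checked mechanically. The following is a minimal Python sketch, not part of the ltmle package: the helper names `node_order` and `censoring_is_monotone` are hypothetical, rows are represented as plain dictionaries with missing values simply absent, and the example row is copied from patient 4 in the `head(data)` output.

```python
def node_order(k, baseline=("W1", "W2")):
    """Column names in the assumed time ordering: W, then L_t, A_t, C_t, Y_t per visit."""
    cols = list(baseline)
    for t in range(1, k + 1):
        cols += [f"L{t}", f"A{t}", f"C{t}", f"Y{t}"]
    return cols


def censoring_is_monotone(row, k):
    """True if every node after the first 'censored' C node is missing (None/absent)."""
    cols = node_order(k)
    for t in range(1, k + 1):
        if row.get(f"C{t}") == "censored":
            later = cols[cols.index(f"C{t}") + 1:]
            return all(row.get(c) is None for c in later)
    return True  # never censored


# Patient 4 from the printed data: censored at C2, everything afterwards is NA.
patient_4 = {"W1": 0.07050839, "W2": 1, "L1": 0.6537786, "A1": 0,
             "C1": "uncensored", "Y1": 0, "L2": 2.6829173, "A2": 0,
             "C2": "censored"}
# Hypothetical invalid row: an outcome recorded after censoring.
invalid = dict(patient_4, Y2=1)

print(node_order(4))
print(censoring_is_monotone(patient_4, 4), censoring_is_monotone(invalid, 4))
```

A check of this kind is cheap to run before fitting, since a column order that disagrees with the assumed time ordering changes which variables count as the "past" of each node.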

Estimating the causal contrast between two static treatment regimes using ltmle

Now, suppose we are interested in the mean counterfactual outcome at the last time point (i.e., \(Y_4\)) if everyone had received treatment at all time points (i.e., \(A_1=A_2=A_3=A_4=1\)), compared to if everyone had received control at all time points (i.e., \(A_1=A_2=A_3=A_4=0\)). Define \(g^{\star,1}\) as the static treatment regime that assigns treatment at all time points with probability 1, and \(g^{\star,0}\) as the static treatment regime that assigns control at all time points with probability 1. Then our target estimand is: \[ E[Y_4^{g^{\star,1}}-Y_4^{g^{\star,0}}]. \]

Under standard causal assumptions (sequential randomization and positivity), the target estimand is identified by the G-computation formula as the difference in the mean of \(Y_4\) under the two post-intervention distributions: \[ E_{P_{Q,g^{\star,1}}}[Y_4]-E_{P_{Q,g^{\star,0}}}[Y_4]. \]

Let’s estimate this target using the ltmle() function. First, we need to encode our knowledge of the statistical model. Here, we know the treatment assignment mechanism at each time point: when making the treatment decision at visit \(t\), the doctor considers only the baseline covariates \((W_1,W_2)\), the previous \((L_{t-1})\) and the current \((L_t)\) time-varying covariate. We place no such restriction on the censoring mechanism, which may depend on the entire observed past. We can encode this knowledge using the gform argument in ltmle():

# encode knowledge on conditional independence
gform <- c(
  A1 = "A1 ~ W1 + W2 + L1",
  C1 = "C1 ~ W1 + W2 + L1 + A1",
  A2 = "A2 ~ W1 + W2 + L1 + L2",
  C2 = "C2 ~ W1 + W2 + L1 + A1 + Y1 + L2 + A2",
  A3 = "A3 ~ W1 + W2 + L2 + L3",
  C3 = "C3 ~ W1 + W2 + L1 + A1 + Y1 + L2 + A2 + Y2 + L3 + A3",
  A4 = "A4 ~ W1 + W2 + L3 + L4",
  C4 = "C4 ~ W1 + W2 + L1 + A1 + Y1 + L2 + A2 + Y2 + L3 + A3 + Y3 + L4 + A4"
)
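
For a larger number of time points, writing these formulas by hand becomes error-prone. The following is a minimal Python sketch of how the same strings could be generated from the stated rule (treatment depends on the baseline covariates and the previous and current \(L\); censoring depends on the full observed past). The helper `build_gform` is hypothetical and only produces formula strings; it is not an ltmle function.

```python
def build_gform(k, baseline=("W1", "W2")):
    """Build treatment and censoring formula strings for visits 1..k."""
    gform = {}
    for t in range(1, k + 1):
        # Treatment: baseline covariates plus the previous and current L only.
        recent_l = [f"L{s}" for s in (t - 1, t) if s >= 1]
        gform[f"A{t}"] = f"A{t} ~ " + " + ".join([*baseline, *recent_l])

        # Censoring: baseline covariates plus the full observed past up to A_t.
        past = []
        for s in range(1, t + 1):
            past += [f"L{s}", f"A{s}"]
            if s < t:
                past.append(f"Y{s}")  # Y_t itself comes after C_t
        gform[f"C{t}"] = f"C{t} ~ " + " + ".join([*baseline, *past])
    return gform


for name, formula in build_gform(4).items():
    print(f'{name} = "{formula}"')
```

With `k = 4`, the generated strings coincide with the hand-written `gform` vector above, in the same A1, C1, ..., A4, C4 order.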

Next, we need to define the counterfactual treatment regimes of interest. Here, we are interested in the contrast between two static regimes: \(A_1=A_2=A_3=A_4=1\) and \(A_1=A_2=A_3=A_4=0\):

abar <- list(rep(1, K), rep(0, K))
abar
[[1]]
[1] 1 1 1 1

[[2]]
[1] 0 0 0 0
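
To make concrete what the G-computation formula evaluates for each regime in abar, here is a minimal Python sketch on a toy problem. Everything in it is an assumption made for illustration: it uses two visits instead of four, binary covariates, no censoring, and made-up conditional distributions that are treated as known. It is not the tutorial's DGP and not what ltmle() does internally, which estimates these quantities from data.

```python
from itertools import product

# Hypothetical, fully known toy distribution (two visits, binary L1 and L2).
P_L1 = 0.5  # P(L1 = 1)


def p_l2(l1, a1):
    """P(L2 = 1 | L1 = l1, A1 = a1)."""
    return 0.2 + 0.4 * l1 - 0.1 * a1


def mean_y(l1, a1, l2, a2):
    """E[Y | L1 = l1, A1 = a1, L2 = l2, A2 = a2]."""
    return 0.2 + 0.1 * l1 + 0.1 * a1 + 0.2 * l2 + 0.2 * a2


def bernoulli(p, x):
    """Probability that a Bernoulli(p) variable equals x."""
    return p if x == 1 else 1 - p


def g_formula(abar):
    """Mean outcome under the static regime abar = (a1, a2).

    Sums E[Y | l1, a1, l2, a2] * P(l2 | l1, a1) * P(l1) over all covariate
    histories, with treatment fixed at the regime's values.
    """
    a1, a2 = abar
    total = 0.0
    for l1, l2 in product((0, 1), repeat=2):
        weight = bernoulli(P_L1, l1) * bernoulli(p_l2(l1, a1), l2)
        total += weight * mean_y(l1, a1, l2, a2)
    return total


psi_1 = g_formula((1, 1))  # always treat
psi_0 = g_formula((0, 0))  # never treat
print(psi_1, psi_0, psi_1 - psi_0)
```

Note that the covariate distribution at the second visit is evaluated under the intervened treatment value, which is what distinguishes the G-computation formula from a single regression of the outcome on the treatment history.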

Now we are ready to run ltmle():

fit_ltmle <- ltmle(data = data,
                   Anodes = Anodes,
                   Cnodes = Cnodes,
                   Lnodes = Lnodes,
                   Ynodes = Ynodes,
                   abar = abar,
                   survivalOutcome = FALSE,
                   gform = gform)
Qform not specified, using defaults:
formula for Y1:
Q.kplus1 ~ W1 + W2 + L1 + A1
formula for Y2:
Q.kplus1 ~ W1 + W2 + L1 + A1 + Y1 + L2 + A2
formula for Y3:
Q.kplus1 ~ W1 + W2 + L1 + A1 + Y1 + L2 + A2 + Y2 + L3 + A3
formula for Y4:
Q.kplus1 ~ W1 + W2 + L1 + A1 + Y1 + L2 + A2 + Y2 + L3 + A3 +     Y3 + L4 + A4
Estimate of time to completion: < 1 minute
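
The default Qform formulas printed above follow the sequential-regression (iterated conditional expectation) representation of the G-computation formula. As a sketch, in notation that is ours rather than the package's: write \(H_t=(W,L_1,Y_1,\dots,L_{t-1},Y_{t-1},L_t)\) for the covariate and outcome history available before the treatment decision at visit \(t\), and let \(\bar{a}_t=(a_1,\dots,a_t)\) denote the regime's treatment values through visit \(t\). Set \(\bar{Q}_5=Y_4\) and, for \(t=4,3,2,1\), define \[ \bar{Q}_t(H_t)=E\left[\bar{Q}_{t+1}\mid H_t,\ \bar{A}_t=\bar{a}_t,\ \text{uncensored through visit } t\right]. \] The mean outcome under the regime is then \[ E_{P_{Q,g^{\star}}}[Y_4]=E\left[\bar{Q}_1(H_1)\right]. \] Each printed formula is the regression for one step of this recursion, with Q.kplus1 playing the role of \(\bar{Q}_{t+1}\). For example, the formula for Y2 regresses on \(H_2=(W_1,W_2,L_1,Y_1,L_2)\) together with \(A_1\) and \(A_2\). TMLE additionally updates each of these regressions in a targeting step that uses the estimated treatment and censoring mechanisms specified through gform.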

Finally, we can summarize the results:

summary(fit_ltmle)
Estimator:  tmle 
Call:
ltmle(data = data, Anodes = Anodes, Cnodes = Cnodes, Lnodes = Lnodes, 
    Ynodes = Ynodes, survivalOutcome = FALSE, gform = gform, 
    abar = abar)

Treatment Estimate:
   Parameter Estimate:  0.75182 
    Estimated Std Err:  0.041583 
              p-value:  <2e-16 
    95% Conf Interval: (0.67031, 0.83332) 

Control Estimate:
   Parameter Estimate:  0.45203 
    Estimated Std Err:  0.05186 
              p-value:  <2e-16 
    95% Conf Interval: (0.35038, 0.55367) 

Additive Treatment Effect:
   Parameter Estimate:  0.29979 
    Estimated Std Err:  0.066472 
              p-value:  6.4831e-06 
    95% Conf Interval: (0.16951, 0.43007) 

Relative Risk:
   Parameter Estimate:  1.6632 
  Est Std Err log(RR):  0.12736 
              p-value:  6.4839e-05 
    95% Conf Interval: (1.2958, 2.1348) 

Odds Ratio:
   Parameter Estimate:  3.6723 
  Est Std Err log(OR):  0.30578 
              p-value:  2.0991e-05 
    95% Conf Interval: (2.0168, 6.6867) 

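The true value of the estimand under this DGP is 0.34. Our estimate of the additive treatment effect, 0.30, is close to it, and the 95% confidence interval (0.17, 0.43) contains the true value.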
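
The contrasts in the summary() output can be reproduced by hand from the two regime-specific estimates. The following Python sketch copies the point estimates and standard errors from the output above and recomputes the additive effect, relative risk, and odds ratio, together with Wald-type intervals. That the package forms its intervals in exactly this way is an assumption here; the sketch only checks that the printed numbers are consistent with it.

```python
from math import exp, log
from statistics import NormalDist

# Values copied from the summary() output above.
psi_treat, psi_control = 0.75182, 0.45203
se_ate, se_log_rr = 0.066472, 0.12736


def odds(p):
    """Odds corresponding to a probability p."""
    return p / (1 - p)


# Point estimates of the three contrasts.
ate = psi_treat - psi_control
rr = psi_treat / psi_control
odds_ratio = odds(psi_treat) / odds(psi_control)

# Wald-type 95% intervals: on the natural scale for the additive effect,
# and on the log scale (then exponentiated) for the relative risk.
z975 = NormalDist().inv_cdf(0.975)
ate_ci = (ate - z975 * se_ate, ate + z975 * se_ate)
rr_ci = tuple(exp(log(rr) + sign * z975 * se_log_rr) for sign in (-1, 1))

# Two-sided p-value for the additive effect.
ate_p = 2 * (1 - NormalDist().cdf(abs(ate / se_ate)))

print("ATE", ate, ate_ci, ate_p)
print("RR ", rr, rr_ci)
print("OR ", odds_ratio)
```

The relative-risk interval is asymmetric around the point estimate because it is constructed on the log scale, which matches the "Est Std Err log(RR)" label in the output.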

DeepLTMLE

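The Python package dltmle implements DeepLTMLE. The example below simulates data from the package's example DGP, selects hyperparameters from a grid of candidates, and then estimates the mean counterfactual outcome under the always-control and always-treated regimes, from which the risk difference, risk ratio, and odds ratio are computed.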

import dltmle
import numpy as np

W, L, A, C, Y = dltmle.example_dgp(np.random.default_rng(0), 1000, 10)

hparams_candidates = {
    'dim_emb': [8, 16],
    'dim_emb_time': [4, 8],
    'dim_emb_type': [4, 8],
    'hidden_size': [8, 16, 32],
    'num_layers': [1, 2, 4],
    'nhead': [2, 4],
    'dropout': [0, 0.1, 0.2],
    'learning_rate': [1e-3, 5e-4, 1e-4, 5e-5],
    'alpha': [0.05, 0.1, 0.5, 1],
    'beta': [0.05, 0.1, 0.5, 1],
    'max_epochs': [100],
    'batch_size': [64],
}

hparams = dltmle.tune(0, hparams_candidates, W, L, A, C, Y)

psi_0 = dltmle.fit(0, hparams, W, L, A, C, Y, np.zeros_like(A))
psi_1 = dltmle.fit(0, hparams, W, L, A, C, Y, np.ones_like(A))

print('mean counterfactual outcome under a = 0', psi_0)
print('mean counterfactual outcome under a = 1', psi_1)
print('ATE (risk difference)', psi_1 - psi_0)
print('risk ratio', psi_1 / psi_0)
print('odds ratio', (psi_1 / (1 - psi_1)) / (psi_0 / (1 - psi_0)))