
ATE Estimation for Count Data

This notebook is a continuation of the previous notebook on ATE estimation for binary data with logistic regression, based on the sequence of (great!) posts by Solomon Kurz. In this notebook we focus on count data: we reproduce in Python an example presented in the post Causal inference with count regression by Solomon Kurz. Our intention is simply to show how to port this type of model to bambi. In addition, as in the previous post, we compare the ATE estimation with a simple linear regression model.

Prepare Notebook

import arviz as az
import bambi as bmb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


plt.style.use("bmh")
plt.rcParams["figure.figsize"] = [10, 6]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
seed: int = sum(map(ord, "logistic"))
rng: np.random.Generator = np.random.default_rng(seed=seed)

Read Data

The data used in this example is the epilepsy data set available in the brms package. From the package documentation:

Breslow and Clayton (1993) analyze data initially provided by Thall and Vail (1990) concerning seizure counts in a randomized trial of anti-convulsant therapy in epilepsy. Covariates are treatment, 8-week baseline seizure counts, and age of the patients in years.

Solomon’s blog gives us more detail on the data set:

The original study was designed to see how well the anti-convulsant drug progabide reduced epileptic seizures in persons who experienced simple or complex partial seizures. After a baseline assessment, participants were randomized into treatment groups in which they received either progabide (Trt == 1) or a placebo (Trt == 0). The primary outcome variable in the data is count, which is the number of epileptic seizures each participant had had since their last visit. Following their baseline assessments, participants came back every two weeks for assessments, which are denoted by the factor variable visit.

To reproduce the analysis, we exported the data from the brms package and saved it as a CSV file.

raw_df = pd.read_csv("../data/brms_epilepsy.csv")

df = raw_df.copy()
df.columns = df.columns.str.lower()

df.head()
age base trt patient visit count obs zage zbase
0 31 11 0 1 1 5 1 0.424995 -0.757173
1 30 11 0 2 1 3 2 0.265284 -0.757173
2 25 6 0 3 1 2 3 -0.533274 -0.944403
3 36 8 0 4 1 4 4 1.223553 -0.869511
4 22 66 0 5 1 7 5 -1.012408 1.302363
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 236 entries, 0 to 235
Data columns (total 9 columns):
 #   Column   Non-Null Count  Dtype  
---  ------   --------------  -----  
 0   age      236 non-null    int64  
 1   base     236 non-null    int64  
 2   trt      236 non-null    int64  
 3   patient  236 non-null    int64  
 4   visit    236 non-null    int64  
 5   count    236 non-null    int64  
 6   obs      236 non-null    int64  
 7   zage     236 non-null    float64
 8   zbase    236 non-null    float64
dtypes: float64(2), int64(7)
memory usage: 16.7 KB

Let’s start by replicating one of the initial plots from the post, where we see the count per visit for all patients, split by treatment group.

g = sns.relplot(
    data=df,
    x="visit",
    y="count",
    col="trt",
    kind="line",
    units="patient",
    estimator=None,
    color="C0",
    alpha=0.5,
)

Next, we compute some basic statistics of the target variable count, split by treatment group trt and visit.

df.groupby(["trt", "visit"]).agg({"count": [np.mean, np.var, np.min, np.max]})
count
mean var amin amax
trt visit
0 1 9.357143 102.756614 0 40
2 8.285714 66.656085 0 29
3 8.714286 213.322751 0 76
4 7.964286 58.183862 0 29
1 1 8.580645 332.718280 0 102
2 8.419355 140.651613 0 65
3 8.129032 193.049462 0 72
4 6.709677 126.879570 0 63

Data Preparation

For the sake of this analysis, we keep only the last visit.

data = df.query("visit == 4").reset_index(drop=True)

data.head()
age base trt patient visit count obs zage zbase
0 31 11 0 1 4 3 178 0.424995 -0.757173
1 30 11 0 2 4 3 179 0.265284 -0.757173
2 25 6 0 3 4 5 180 -0.533274 -0.944403
3 36 8 0 4 4 4 181 1.223553 -0.869511
4 22 66 0 5 4 21 182 -1.012408 1.302363

We can verify the number of participants in each group.

data.groupby("trt").agg({"patient": "nunique"})
patient
trt
0 28
1 31

Difference in Means

As in the previous post, since this data set comes from a randomized controlled trial (RCT), we can use the difference in means (DIM) as a baseline for the ATE estimation.

diff_means = (
    data.query("trt == 1")["count"].mean() - data.query("trt == 0")["count"].mean()
)

print(f"Sample ATE: {diff_means:.3f}")
Sample ATE: -1.255
fig, ax = plt.subplots()
sns.histplot(data=data, x="count", hue="trt", stat="density", ax=ax)
sns.kdeplot(data=data, x="count", hue="trt", clip=(0, None), ax=ax)
ax.set(title="Distribution of Counts by Treatment Group")

We would like to quantify the uncertainty around this estimate.
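
Before turning to the Bayesian models, here is a quick non-Bayesian reference point (a small addition, not in the original post): bootstrapping the difference in means gives a rough interval around the sample ATE.

treated = data.query("trt == 1")["count"].to_numpy()
control = data.query("trt == 0")["count"].to_numpy()

# Resample each group with replacement and recompute the difference in means.
boot_dims = np.array(
    [
        rng.choice(treated, size=treated.size).mean()
        - rng.choice(control, size=control.size).mean()
        for _ in range(2_000)
    ]
)

print(f"Bootstrap 94% interval: {np.quantile(boot_dims, [0.03, 0.97]).round(3)}")

The Bayesian models below quantify this uncertainty more directly.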

Negative Binomial Model

A natural likelihood function for count data is the negative binomial distribution. We could have used a Poisson distribution, but the summary statistics above show group variances far exceeding the group means, so we expect overdispersion in the data.
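
We can make the overdispersion explicit with a quick check (a small addition to the original post): under a Poisson likelihood the variance equals the mean, while here the sample variances exceed the means by an order of magnitude.

# Variance-to-mean ratio per treatment group; values near 1 would support
# a Poisson likelihood, values well above 1 indicate overdispersion.
dispersion = data.groupby("trt")["count"].agg(["mean", "var"])
dispersion["var/mean"] = dispersion["var"] / dispersion["mean"]
print(dispersion)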

We define our model in the same manner as in Solomon’s post:

negative_binomial_priors = {
    "Intercept": bmb.Prior("Normal", mu=np.log(14), sigma=1),
    "trt": bmb.Prior("Normal", mu=0, sigma=0.5),
    "alpha": bmb.Prior("Gamma", alpha=0.01, beta=0.01),
}

negative_binomial_model = bmb.Model(
    formula="count ~ trt",
    data=data,
    family="negativebinomial",
    link="log",
    priors=negative_binomial_priors,
)

negative_binomial_model
       Formula: count ~ trt
        Family: negativebinomial
          Link: mu = log
  Observations: 59
        Priors: 
    target = mu
        Common-level effects
            Intercept ~ Normal(mu: 2.6391, sigma: 1.0)
            trt ~ Normal(mu: 0.0, sigma: 0.5)
        
        Auxiliary parameters
            count_alpha ~ Gamma(alpha: 0.01, beta: 0.01)
negative_binomial_model.build()
negative_binomial_model.graph()

Before fitting the model, let’s take a look at the priors. One of the aspects I like most about Solomon’s posts is the discussion of prior selection. In this case:

By centering the prior for (the intercept) \(\beta_{0}\) on \(\log(14)\), I’m betting those in the control (placebo) condition will report about \(1\) seizure a day (i.e., \(14\) during the last 2-week period). But since I know very little about seizures, and even less about seizure medication trials, I’m very unsure about that prior, which is why I’ve set the prior standard deviation to \(1\).

Observe that we are using a \(\log\) link function to keep the mean of the negative binomial distribution positive. This explains the \(\log\) transformation of the prior mean of the intercept \(\beta_{0}\).
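
To see what this prior implies on the outcome scale, we can push draws of \(\beta_{0}\) through the inverse link (a quick sanity check, not in the original post):

# Implied prior on the control-group mean: exp(beta_0) with
# beta_0 ~ Normal(log(14), 1). The median is 14 counts per two-week
# period, with a wide spread reflecting our uncertainty.
beta_0_prior = rng.normal(loc=np.log(14), scale=1.0, size=10_000)
print(np.quantile(np.exp(beta_0_prior), [0.03, 0.5, 0.97]).round(1))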

In addition,

The \(\text{Gamma}(0.01,0.01)\) prior for \(\alpha\) is the default in the brms package.

For more tips and tricks regarding prior specification for negative binomial models, please look into this Stan guide.
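
To get a numerical feel for how diffuse this default prior is, we can sample from it directly (again, a small addition):

# Gamma(alpha=0.01, beta=0.01) in the shape/rate parametrization;
# numpy uses shape/scale, so scale = 1 / beta. The quantiles span
# many orders of magnitude, i.e., the prior is extremely diffuse.
alpha_prior = rng.gamma(shape=0.01, scale=1 / 0.01, size=10_000)
print(np.quantile(alpha_prior, [0.25, 0.5, 0.75, 0.99]))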

negative_binomial_model.plot_priors(draws=10_000, figsize=(15, 7), random_seed=rng)
plt.gcf().suptitle("Negative Binomial Model Priors", fontsize=16)

Now we proceed to fit the model:

negative_binomial_idata = negative_binomial_model.fit(
    draws=4_000, chains=5, nuts_sampler="numpyro", random_seed=rng
)

negative_binomial_model.predict(idata=negative_binomial_idata, kind="pps")

Next, we look at the summary table and the trace plots:

az.summary(data=negative_binomial_idata, var_names=["Intercept", "trt", "count_alpha"])
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
Intercept 2.079 0.190 1.732 2.446 0.001 0.001 19192.0 13083.0 1.0
trt -0.132 0.246 -0.594 0.327 0.002 0.002 18391.0 14179.0 1.0
count_alpha 1.038 0.222 0.645 1.453 0.002 0.001 17646.0 15257.0 1.0
axes = az.plot_trace(
    data=negative_binomial_idata,
    var_names=["Intercept", "trt", "count_alpha"],
    compact=True,
    kind="rank_bars",
    backend_kwargs={"figsize": (10, 6), "layout": "constrained"},
)
plt.gcf().suptitle("Negative Binomial Regression Model - Trace", fontsize=16)

As explained in the previous post, to recover the ATE estimate we need to compute the difference between the posterior mean of the treatment group and the posterior mean of the control group, taking the inverse of the link function into account. In this case we need to take the exponential. If we denote by \(\beta_{0}\) the intercept and by \(\beta_{1}\) the coefficient of the treatment variable trt, then the posterior mean of the treatment group is given by \(\exp(\beta_{0} + \beta_{1})\) and the posterior mean of the control group is given by \(\exp(\beta_{0})\). Therefore, the ATE estimate is given by

\[\exp(\beta_{0} + \beta_{1}) - \exp(\beta_{0})\]
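
As a rough plug-in check (a small addition), we can evaluate this expression at the posterior means reported in the summary table above. By Jensen's inequality this is not exactly the posterior mean of the ATE, but it should be close.

# Plug-in estimate using the posterior means of Intercept and trt.
beta_0_hat, beta_1_hat = 2.079, -0.132
print(f"Plug-in ATE: {np.exp(beta_0_hat + beta_1_hat) - np.exp(beta_0_hat):.3f}")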

We can easily compute this from the trace:

ate_samples = np.exp(
    negative_binomial_idata["posterior"]["Intercept"]
    + negative_binomial_idata["posterior"]["trt"]
) - np.exp(negative_binomial_idata["posterior"]["Intercept"])

fig, ax = plt.subplots()
az.plot_posterior(data=ate_samples, ref_val=diff_means, ax=ax)
ax.set(title="Negative Binomial Model - ATE", xlabel="ATE")

We see how the ATE posterior distribution is centered around the sample ATE computed from the difference in means.
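
We can also summarize the ATE samples numerically (a small addition) to make the comparison concrete:

# Posterior mean and 94% equal-tailed interval of the ATE samples.
ate_np = ate_samples.values.flatten()
print(f"Posterior mean ATE: {ate_np.mean():.3f}")
print(f"94% interval: {np.quantile(ate_np, [0.03, 0.97]).round(3)}")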

Linear Regression Model (OLS)

We now fit a similar model, using linear regression (a Gaussian likelihood) with an identity link function.

gaussian_model_priors = {
    "Intercept": bmb.Prior("Normal", mu=14, sigma=3),
    "trt": bmb.Prior("Normal", mu=0, sigma=5),
    "sigma": bmb.Prior("Exponential", lam=1 / 10),
}
gaussian_model = bmb.Model(
    formula="count ~ trt",
    data=data,
    family="gaussian",
    link="identity",
    priors=gaussian_model_priors,
)

gaussian_model
       Formula: count ~ trt
        Family: gaussian
          Link: mu = identity
  Observations: 59
        Priors: 
    target = mu
        Common-level effects
            Intercept ~ Normal(mu: 14.0, sigma: 3.0)
            trt ~ Normal(mu: 0.0, sigma: 5.0)
        
        Auxiliary parameters
            count_sigma ~ Exponential(lam: 0.1)
gaussian_model.build()
gaussian_model.graph()

We fit the model and look into the results and diagnostics.

gaussian_idata = gaussian_model.fit(
    draws=4_000, chains=5, nuts_sampler="numpyro", random_seed=rng
)

gaussian_model.predict(idata=gaussian_idata, kind="pps")
az.summary(data=gaussian_idata, var_names=["Intercept", "trt", "count_sigma"])
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
Intercept 8.877 1.693 5.712 12.051 0.012 0.009 18427.0 13652.0 1.0
trt -0.999 2.265 -5.148 3.367 0.017 0.015 18344.0 14465.0 1.0
count_sigma 9.879 0.943 8.165 11.662 0.007 0.005 18528.0 14232.0 1.0
axes = az.plot_trace(
    data=gaussian_idata,
    var_names=["Intercept", "trt", "count_sigma"],
    compact=True,
    kind="rank_bars",
    backend_kwargs={"figsize": (10, 6), "layout": "constrained"},
)
plt.gcf().suptitle("Linear Regression Model - Trace", fontsize=16)
fig, ax = plt.subplots(
    nrows=2, ncols=1, sharex=True, sharey=True, figsize=(10, 8), layout="constrained"
)
az.plot_posterior(data=ate_samples, ref_val=diff_means, ax=ax[0])
ax[0].set(title="Logistic Regression Model")
az.plot_posterior(data=gaussian_idata, var_names=["trt"], ref_val=diff_means, ax=ax[1])
ax[1].set(title="Negative Binomial Regression Model", xlabel="ATE")
fig.suptitle("ATE Comparison", y=1.05, fontsize=16)

The resulting ATE posterior distributions are very similar. Note, however, that the posterior means do not exactly match the sample ATE computed from the difference in means. Still, given the posterior uncertainty, the estimates look compatible.

Which model should we choose? As discussed in the previous post, the negative binomial model is conceptually more appropriate. Moreover, by running some experiments I found that the linear regression model is very sensitive to the prior specification of the treatment variable trt. It is also easier to reason about the priors in the negative binomial model, as one is actually modeling the underlying data-generating distribution. This is yet another reason to prefer the negative binomial model.
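
For a more formal comparison, one could look at out-of-sample predictive performance with PSIS-LOO. Here is a minimal sketch, assuming both models were fit with pointwise log-likelihoods stored (e.g., by passing idata_kwargs={"log_likelihood": True} to fit, which we did not do above):

# Sketch: rank the two models by expected log predictive density (LOO).
# Requires the log_likelihood group in both InferenceData objects.
comparison = az.compare(
    {
        "negative_binomial": negative_binomial_idata,
        "gaussian": gaussian_idata,
    }
)
print(comparison)

Finally, we compare the models visually with a posterior predictive check: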

fig, ax = plt.subplots(
    nrows=2, ncols=1, sharex=True, sharey=False, figsize=(10, 8), layout="constrained"
)
az.plot_ppc(data=negative_binomial_idata, num_pp_samples=1_000, ax=ax[0])
ax[0].set(title="Negative Binomial Model", xlabel="count")
az.plot_ppc(data=gaussian_idata, num_pp_samples=1_000, ax=ax[1])
ax[1].set(title="Linear Regression Model", xlabel="count")
fig.suptitle("Posterior Predictive Check", y=1.05, fontsize=16)