Synthetic control can be considered “the most important innovation in the policy evaluation literature in the last 15 years” (see The State of Applied Econometrics: Causality and Policy Evaluation by Susan Athey and Guido W. Imbens).
In this notebook we provide an example of how to implement a synthetic control model in PyMC to answer a “what would have happened if …?” type of counterfactual question in the context of causal inference. We reproduce the results of the example provided in the great book Causal Inference for The Brave and True by Matheus Facure. Specifically, we look into the problem of estimating the effect of cigarette taxation on consumption, presented in Chapter 15 - Synthetic Control. The purpose of this notebook is not to give a general description of the method but rather to show how to implement it in PyMC. Therefore, we strongly recommend looking into the book to understand the problem, the motivation and the details of the method.
Prepare Notebook
import aesara.tensor as at
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pymc.sampling_jax
import seaborn as sns
plt.style.use("bmh")
plt.rcParams["figure.figsize"] = [10, 6]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"
%load_ext rich
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
Read Data
We read the data directly from the book’s repository.
data_path = "https://raw.githubusercontent.com/matheusfacure/python-causality-handbook/master/causal-inference-for-the-brave-and-true/data/smoking.csv"
raw_data_df = pd.read_csv(data_path)
raw_data_df.head()
| | state | year | cigsale | lnincome | beer | age15to24 | retprice | california | after_treatment |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 1970 | 89.800003 | NaN | NaN | 0.178862 | 39.599998 | False | False |
| 1 | 1 | 1971 | 95.400002 | NaN | NaN | 0.179928 | 42.700001 | False | False |
| 2 | 1 | 1972 | 101.099998 | 9.498476 | NaN | 0.180994 | 42.299999 | False | False |
| 3 | 1 | 1973 | 102.900002 | 9.550107 | NaN | 0.182060 | 42.099998 | False | False |
| 4 | 1 | 1974 | 108.199997 | 9.537163 | NaN | 0.183126 | 43.099998 | False | False |
We are mainly interested in the cigsale variable as the target and retprice as a control. Also note that we already have an indicator column california, which is true for the state with index \(3\).
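As a quick optional sanity check (not in the original example), we can confirm that the california flag indeed corresponds to state index \(3\):
# optional check: the california indicator should map to a single state id
raw_data_df.query("california")["state"].unique()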
# select relevant columns
df = raw_data_df.copy().drop(columns=["lnincome", "beer", "age15to24"])
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1209 entries, 0 to 1208
Data columns (total 6 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 state 1209 non-null int64
1 year 1209 non-null int64
2 cigsale 1209 non-null float64
3 retprice 1209 non-null float64
4 california 1209 non-null bool
5 after_treatment 1209 non-null bool
dtypes: bool(2), float64(2), int64(2)
memory usage: 40.3 KB
EDA
Let’s plot the sales data for California and the other states together.
fig, ax = plt.subplots()
(
df.groupby(["year", "california"], as_index=False)
.agg({"cigsale": np.mean})
.pipe(
(sns.lineplot, "data"),
x="year",
y="cigsale",
hue="california",
marker="o",
ax=ax,
)
)
ax.axvline(
x=1988,
linestyle=":",
lw=2,
color="C2",
label="Proposition 99",
)
ax.legend(loc="upper right")
ax.set(
    title="Cigarette Sales Trend",
    ylabel="Per-capita cigarette sales (in packs)",
);
Data Processing
We now split the data into pre- and post-treatment years. Moreover, we shape the data in such a way that the columns correspond to states (see here for details).
features = ["cigsale", "retprice"]
pre_df = (
df
.query("~ after_treatment")
.pivot(index='state', columns="year", values=features)
.T
)
post_df = (
df
.query("after_treatment")
.pivot(index='state', columns="year", values=features)
.T
)
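As an optional check (not in the book), we can verify the shape of the resulting matrices: after the pivot and transpose, the rows are (feature, year) pairs and the columns are the state indices.
# optional check: rows are (feature, year) pairs, columns are state indices
print(pre_df.shape, post_df.shape)
pre_df.head()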
Finally, we extract some useful information about the data.
# treated unit: California has state index 3
idx = 3
# pre-treatment target (California) and predictors (all other states)
y_pre = pre_df[idx].to_numpy()
x_pre = pre_df.drop(columns=idx).to_numpy()
pre_years = pre_df.reset_index(inplace=False).year.unique()
n_pre = pre_years.size
# post-treatment target and predictors
y_post = post_df[idx].to_numpy()
x_post = post_df.drop(columns=idx).to_numpy()
post_years = post_df.reset_index(inplace=False).year.unique()
n_post = post_years.size
# number of states in the donor pool
k = pre_df.shape[1] - 1
Model
Here is the PyMC model for the synthetic control problem. The only thing to remark is the appearance of the Dirichlet distribution as prior for the model weights. This ensures the weights are all positive and add up to one, as required.
Remark: Note that the prior parameter a coincides with the initial point w_start in the get_w(X, y) function in the book (see here).
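To make the constraint explicit, here is a minimal illustration (not part of the original analysis; it assumes PyMC v4’s pm.draw helper is available, and w_draws is just a name for this check): every draw from this Dirichlet prior is non-negative and sums to one.
# illustration: Dirichlet draws are non-negative and sum to one,
# which is exactly the constraint we want on the synthetic control weights
w_draws = pm.draw(pm.Dirichlet.dist(a=(1 / k) * np.ones(k)), draws=3)
print((w_draws >= 0).all(), w_draws.sum(axis=1))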
with pm.Model() as model:
x = pm.MutableData(name="x", value=x_pre)
y = pm.MutableData(name="y", value=y_pre)
beta = pm.Dirichlet(name="beta", a=(1 / k) * np.ones(k))
sigma = pm.HalfNormal(name="sigma", sigma=5)
mu = pm.Deterministic(name="mu", var=pm.math.dot(x, beta))
likelihood = pm.Normal(name="likelihood", mu=mu, sigma=sigma, observed=y)
pm.model_to_graphviz(model)
We now sample from the model to obtain the posterior distributions.
with model:
idata = pm.sampling_jax.sample_numpyro_nuts(draws=4000, chains=4)
posterior_predictive_pre = pm.sample_posterior_predictive(trace=idata)
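Before looking at the results, it is good practice to check the sampler diagnostics. Here is a minimal check (an addition, not in the original example) using ArviZ:
# quick sampler diagnostics: r_hat should be close to 1 and the effective
# sample sizes reasonably large
az.summary(data=idata, var_names=["beta", "sigma"], round_to=3)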
Let’s plot the posterior distributions for the weights.
az.plot_forest(data=idata, combined=True, var_names=["beta"]);
Note that many of the weights are essentially zero and that the posterior means of the non-zero weights coincide with the weights from the book. We can confirm that the weights satisfy the constraints, namely that they are all positive and add up to one.
np.unique(
idata
.posterior["beta"]
.stack(samples=("chain", "draw"))
.sum(axis=0)
.to_numpy()
- 1
)
array([-5.55111512e-16, -4.44089210e-16, -3.33066907e-16, -2.22044605e-16, -1.11022302e-16, 0.00000000e+00, 2.22044605e-16, 4.44089210e-16])
Posterior Predictive
We can now generate predictions for the post treatment years.
with model:
pm.set_data(new_data={"x": x_post, "y": y_post})
posterior_predictive_post = pm.sample_posterior_predictive(
trace=idata, var_names=["likelihood"]
)
We can now visualize the results (same plot as in the book, but this time with credible intervals!).
pre_posterior_mean = (
posterior_predictive_pre.posterior_predictive["likelihood"][:, :, :n_pre]
.stack(samples=("chain", "draw"))
.mean(axis=1)
)
post_posterior_mean = (
posterior_predictive_post.posterior_predictive["likelihood"][:, :, :n_post]
.stack(samples=("chain", "draw"))
.mean(axis=1)
)
fig, ax = plt.subplots()
(
df.groupby(["year", "california"], as_index=False)
.agg({"cigsale": np.mean})
.assign(
california=lambda x: x.california.map(
{True: "is_california", False: "is_not_california"}
)
)
.pipe(
(sns.lineplot, "data"),
x="year",
y="cigsale",
hue="california",
alpha=0.5,
ax=ax,
)
)
ax.axvline(
x=1988,
linestyle=":",
lw=2,
color="C2",
label="Proposition 99",
)
sns.lineplot(
x=pre_years,
y=pre_posterior_mean,
color="C1",
marker="o",
label="pre-treatment posterior predictive mean",
ax=ax,
)
sns.lineplot(
x=post_years,
y=post_posterior_mean,
color="C2",
marker="o",
label="post-treatment posterior predictive mean",
ax=ax,
)
az.plot_hdi(
x=pre_years,
y=posterior_predictive_pre.posterior_predictive["likelihood"][:, :, :n_pre],
smooth=True,
color="C1",
fill_kwargs={"label": "pre-treatment posterior predictive (94% HDI)"},
ax=ax,
)
az.plot_hdi(
x=post_years,
y=posterior_predictive_post.posterior_predictive["likelihood"][:, :, :n_post],
smooth=True,
color="C2",
fill_kwargs={"label": "post-treatment posterior predictive (94% HDI)"},
ax=ax,
)
ax.legend(loc="lower left")
ax.set(
    title="Cigarette Sales Trend",
    ylabel="Per-capita cigarette sales (in packs)",
);
We can also plot the “errors”.
fig, ax = plt.subplots()
ax.axvline(
x=1988,
linestyle=":",
lw=2,
color="C2",
label="Proposition 99",
)
sns.lineplot(
x=pre_years,
y=y_pre[:n_pre] - pre_posterior_mean,
color="C1",
marker="o",
label="pre-treatment posterior predictive effect mean",
ax=ax,
)
sns.lineplot(
x=post_years,
y=y_post[:n_post] - post_posterior_mean,
color="C2",
marker="o",
label="post-treatment posterior predictive effect mean",
ax=ax,
)
az.plot_hdi(
x=pre_years,
y=y_pre[:n_pre]
- posterior_predictive_pre.posterior_predictive["likelihood"][:, :, :n_pre],
smooth=True,
color="C1",
fill_kwargs={"label": "pre-treatment posterior predictive effect (94% HDI)"},
ax=ax,
)
az.plot_hdi(
x=post_years,
y=y_post[:n_post]
- posterior_predictive_post.posterior_predictive["likelihood"][:, :, :n_post],
smooth=True,
color="C2",
fill_kwargs={"label": "post-treatment posterior predictive effect (94% HDI)"},
ax=ax,
)
ax.axhline(y=0.0, color="black", linestyle="--", label="zero")
ax.legend(loc="lower left")
ax.set(
title="State - Synthetic Across Time",
ylabel="Gap in per-capita cigarette sales (in packs)",
);
Note that in the pre-treatment years the model errors oscillate around zero whereas in the post-treatment period they have a negative trend.
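As a quick numerical check of this observation (an addition, not in the book; pre_gap and post_gap are just names for this check), we can compare the average gap, actual minus synthetic, before and after the treatment:
# average gap (actual minus synthetic control) before and after Proposition 99
pre_gap = (
    y_pre[:n_pre]
    - posterior_predictive_pre.posterior_predictive["likelihood"][:, :, :n_pre]
)
post_gap = (
    y_post[:n_post]
    - posterior_predictive_post.posterior_predictive["likelihood"][:, :, :n_post]
)
print(f"mean pre-treatment gap: {pre_gap.mean().item():.2f}")
print(f"mean post-treatment gap: {post_gap.mean().item():.2f}")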
Finally, we can take a look at the distribution of the estimated reduction in sales for the year \(2000\).
g = (
(
y_post[:n_post]
- posterior_predictive_post.posterior_predictive["likelihood"][:, :, :n_post]
)[:, :, -1]
.stack(samples=("chain", "draw"))
.pipe((sns.displot, "data"), kde=True, height=4.5, aspect=1.5)
)
g.set(title="Reduction in cigarette sales in 2000");
Hence we obtain the same result as in the book:
By the year 2000, it looks like Proposition 99 has reduced the sales in cigarettes by 25 packs.
However, the Bayesian framework gives additional valuable information about the model’s uncertainty.
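For example, here is a minimal sketch (mirroring the computation used later in this notebook; effect_2000 is just a name for this sketch) of the 94% highest density interval for the year-2000 effect:
# 94% HDI for the estimated effect (gap) in the year 2000
effect_2000 = (
    y_post[:n_post]
    - posterior_predictive_post.posterior_predictive["likelihood"][:, :, :n_post]
)[:, :, -1]
az.hdi(effect_2000, hdi_prob=0.94)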
Making Inference
In this final section we want to implement the procedure described in the book to assess significance (see here). The main idea is to run the model above for all the states and compare the distributions of the resulting synthetic control estimates. First we write a function that wraps the main steps (it’s a long function which could be modularized, but I keep it explicit for the sake of exposition).
def run_synthetic_control(
pre_df: pd.DataFrame, post_df: pd.DataFrame, idx: int
) -> tuple:
# prepare data
y_pre = pre_df[idx].to_numpy()
x_pre = pre_df.drop(columns=idx).to_numpy()
pre_years = pre_df.reset_index(inplace=False).year.unique()
n_pre = pre_years.size
y_post = post_df[idx].to_numpy()
x_post = post_df.drop(columns=idx).to_numpy()
post_years = post_df.reset_index(inplace=False).year.unique()
n_post = post_years.size
k = pre_df.shape[1] - 1
# specify the model
with pm.Model() as model:
x = pm.MutableData(name="x", value=x_pre)
y = pm.MutableData(name="y", value=y_pre)
beta = pm.Dirichlet(name="beta", a=(1 / k) * np.ones(k))
sigma = pm.HalfNormal(name="sigma", sigma=5)
mu = pm.Deterministic(name="mu", var=pm.math.dot(x, beta))
likelihood = pm.Normal(name="likelihood", mu=mu, sigma=sigma, observed=y)
# fit the model
idata = pm.sampling_jax.sample_numpyro_nuts(draws=4000, chains=4)
posterior_predictive_pre = pm.sample_posterior_predictive(trace=idata)
# post-treatment predictive distribution
pm.set_data(new_data={"x": x_post, "y": y_post})
posterior_predictive_post = pm.sample_posterior_predictive(
trace=idata, var_names=["likelihood"]
)
# compute errors
error_pre = (
y_pre[:n_pre]
- posterior_predictive_pre.posterior_predictive["likelihood"][:, :, :n_pre]
)
error_post = (
y_post[:n_post]
- posterior_predictive_post.posterior_predictive["likelihood"][
:, :, :n_post
]
)
return error_pre, error_post
Now we run it (we have to fit 39 Bayesian models, but with the JAX backend it takes less than an hour on my local machine).
from tqdm.notebook import tqdm
results = {
idx: run_synthetic_control(pre_df=pre_df, post_df=post_df, idx=idx)
for idx in tqdm(df["state"].unique())
}
Let’s plot the results. First we look into the means of the posterior distributions. We do this to verify that we obtain the same results as in the book (which we do!).
Remark: We filter out certain states on which the synthetic control does not work well in the pre-treatment period. The book filters them out using the mean squared error. Here we filter out based on the standard deviation of the error posterior distribution (not super important).
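For reference, here is the filter statistic for a single state (California, idx = 3), shown in isolation; it is exactly the quantity computed inside the loops below (error_pre_ca is just a name for this snippet):
# filter statistic: minimum (over years) of the posterior standard deviation
# of the pre-treatment error for a given state
error_pre_ca, _ = results[3]
error_pre_ca.stack(samples=("chain", "draw")).std(axis=1).min().item()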
fig, ax = plt.subplots()
for idx in df["state"].unique():
error_pre, error_post = results[idx]
sigma_pre = error_pre.stack(samples=("chain", "draw")).std(axis=1).min().item()
if sigma_pre < 10:
color = "C6" if idx == 3 else "gray"
alpha = 1 if idx == 3 else 0.3
label = "california" if idx == 3 else None
sns.lineplot(
x=pre_years,
y=error_pre.stack(samples=("chain", "draw")).mean(axis=1),
color=color,
alpha=alpha,
ax=ax,
)
sns.lineplot(
x=post_years,
y=error_post.stack(samples=("chain", "draw")).mean(axis=1),
color=color,
alpha=alpha,
label=label,
ax=ax,
)
ax.axhline(y=0.0, color="black", linestyle="--", label="zero")
ax.legend(loc="lower left")
ax.set(
title="State - Synthetic Across Time",
ylabel="Gap in per-capita cigarette sales (in packs)",
);
Now we can plot the whole posterior distributions.
fig, ax = plt.subplots()
for idx in df["state"].unique():
error_pre, error_post = results[idx]
sigma_pre = error_pre.stack(samples=("chain", "draw")).std(axis=1).min().item()
if sigma_pre < 10:
color = "C6" if idx == 3 else "gray"
alpha = 1 if idx == 3 else 0.05
label = "california" if idx == 3 else None
az.plot_hdi(
x=pre_years,
y=error_pre,
smooth=True,
color=color,
fill_kwargs={"alpha": alpha},
ax=ax,
)
az.plot_hdi(
x=post_years,
y=error_post,
smooth=True,
color=color,
fill_kwargs={"alpha": alpha, "label": label},
ax=ax,
)
ax.axhline(y=0.0, color="black", linestyle="--", label="zero")
ax.legend(loc="lower left")
ax.set(
title="State - Synthetic Across Time",
ylabel="Gap in per-capita cigarette sales (in packs)",
);
Overall, it looks like California has an “extreme” value compared with all the other states. Let’s look at the highest density intervals (HDI) for each state in the year \(2000\).
fig, ax = plt.subplots()
for idx in df["state"].unique():
error_pre, error_post = results[idx]
sigma_pre = error_pre.stack(samples=("chain", "draw")).std(axis=1).min().item()
if sigma_pre < 10:
color = "C6" if idx == 3 else "gray"
alpha = 1 if idx == 3 else 0.8
label = "california" if idx == 3 else None
hdi = az.hdi(error_post[:, :, -1])["likelihood"].to_numpy()
ax.vlines(
x=idx,
ymin=hdi[0],
ymax=hdi[1],
linestyle="solid",
linewidth=3,
color=color,
alpha=alpha,
label=label,
)
ax.axhline(y=0.0, color="black", linestyle="--", label="zero")
ax.legend(loc="upper left")
ax.set(
title="94% HDI reduced the sales in cigarettes in 2000",
xlabel="sate index",
ylabel="Gap in per-capita cigarette sales (in packs)",
);
Here we see that the HDI for California is indeed quite narrow and low. We can now compute the proportion of times the effect in California is bigger than the estimated effects of the other states. First, let us see the explicit ids of the states we are filtering out.
for idx in df["state"].unique():
error_pre, error_post = results[idx]
sigma_pre = error_pre.stack(samples=("chain", "draw")).std(axis=1).min().item()
if sigma_pre >= 10:
print(idx)
13
22
24
34
Now we “simply count”:
cal_samples = np.array(
[
v[1].stack(samples=("chain", "draw")).T.to_numpy()
for k, v in results.items()
if k == 3
]
)
non_cal_samples = np.array(
[
v[1].stack(samples=("chain", "draw")).T.to_numpy()
for k, v in results.items()
if k not in [3, 13, 22, 24, 34]
]
)
samples = (cal_samples[:, :, -1] > non_cal_samples[:, :, -1]).mean(axis=0)
fig, ax = plt.subplots()
sns.histplot(
x=samples,
color="C4",
label=f"mean: {samples.mean(): .3f} and std = {samples.std(): .3f}",
ax=ax,
)
ax.legend()
ax.set(
title="Proportion of times the effect in California is bigger than all the estimated effects"
);
Finally, let’s visualize the distributions:
fig, ax = plt.subplots()
sns.histplot(x=non_cal_samples[:, :, -1].flatten(), color="gray", stat="density", ax=ax)
sns.histplot(
x=cal_samples[:, :, -1].flatten(),
color="C6",
stat="density",
label="california",
ax=ax,
)
ax.legend(loc="upper right")
ax.set(
title="Histograms for year 2000",
xlabel="Gap in per-capita cigarette sales (in packs)",
ylabel="frequency",
);