In this notebook we present a simple approach to study cohort retention analysis through a simulated data set. The aim is to understand how retention rates change over time and provide a simple model to predict them (with uncertainty estimates!). We do not expect this technique to be a silver bullet for all retention problems, but rather a simple approach to get started with the problem.
Remark: A motivation for this notebook was the great post Bayesian Age/Period/Cohort Models in Python with PyMC by Austin Rochford.
Prepare Notebook
import arviz as az
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import numpy.typing as npt
import pandas as pd
import pymc as pm
import pymc.sampling_jax
import seaborn as sns
from scipy.special import expit
from sklearn.preprocessing import LabelEncoder, StandardScaler
plt.style.use("bmh")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
Generate Data
In this section we generate some synthetic retention data. We assume we are interested in monthly retention rates. Here is the logic behind the data generation process: For a given month, say `2020-01`, we acquired \(100\) users, of which \(20\) returned (say, made a purchase) in `2020-02`. The retention rate in `2020-02` for the `2020-01` cohort is then \(20/100 = 0.2\). We simulate cohort retention rates based on the following assumptions:
- The retention rate is a function of the cohort age (i.e. how long the cohort has been alive at the observation period) and the absolute age (i.e. how much time has passed between the first month of the cohort and today, the most recent period in the data). In the code below both ages are measured in days.
- The retention rate has a yearly seasonality component and a trend component.
We do not simulate retention rates directly but rather the number of users that returned back in a given month. We use a Binomial distribution for this purpose.
The following `CohortDataGenerator` class implements the data generation process.
class CohortDataGenerator:
    """Simulate monthly cohort retention data.

    Every cohort is observed in each period from its acquisition month up to
    the last generated month. The true retention rate is a deterministic
    function of the cohort age, the absolute age and the period's day of year;
    the number of active users is then drawn from a Binomial distribution.

    Parameters
    ----------
    rng
        Random number generator used for cohort sizes and active-user counts.
    start_cohort
        First cohort month, in any format accepted by ``pd.period_range``.
    n_cohorts
        Number of consecutive monthly cohorts (and periods) to generate.
    user_base
        Scale of the cohort sizes.
    """

    def __init__(
        self,
        rng: np.random.Generator,
        start_cohort: str,
        n_cohorts: int,
        user_base: int = 10_000,
    ) -> None:
        self.rng = rng
        self.start_cohort = start_cohort
        self.n_cohorts = n_cohorts
        self.user_base = user_base

    def _generate_cohort_labels(self) -> pd.DatetimeIndex:
        """Return the first day of each cohort month, starting at ``start_cohort``."""
        # Use the configured start month rather than a fixed date, so that the
        # ``start_cohort`` constructor argument is actually honoured.
        return pd.period_range(
            start=self.start_cohort, periods=self.n_cohorts, freq="M"
        ).to_timestamp()

    def _generate_cohort_sizes(self) -> npt.NDArray[np.int_]:
        """Draw one size per cohort: ``user_base`` x linear trend x Gamma(1, 1) noise."""
        ones = np.ones(shape=self.n_cohorts)
        # Linear ramp 1/n, 2/n, ..., 1 so that later cohorts tend to be larger.
        trend = ones.cumsum() / ones.sum()
        return (
            (
                self.user_base
                * trend
                * self.rng.gamma(shape=1, scale=1, size=self.n_cohorts)
            )
            .round()
            .astype(int)
        )

    def _generate_dataset_base(self) -> pd.DataFrame:
        """Build the (cohort, period) grid with both age columns, in days."""
        cohorts = self._generate_cohort_labels()
        n_users = self._generate_cohort_sizes()
        data_df = pd.merge(
            left=pd.DataFrame(data={"cohort": cohorts, "n_users": n_users}),
            right=pd.DataFrame(data={"period": cohorts}),
            how="cross",
        )
        # Absolute age: days between the cohort start and the last period.
        data_df["age"] = (data_df["period"].max() - data_df["cohort"]).dt.days
        # Cohort age: days between the cohort start and the observed period.
        data_df["cohort_age"] = (data_df["period"] - data_df["cohort"]).dt.days
        # A cohort cannot be observed before it exists. Copy so that the later
        # column assignments act on an independent frame, not on a slice.
        return data_df.query("cohort_age >= 0").copy()

    def _generate_retention_rates(self, data_df: pd.DataFrame) -> pd.DataFrame:
        """Add the true retention on the logit scale and as a probability.

        The columns ``retention_true_mu`` and ``retention_true`` are added to
        ``data_df`` in place; the same frame is returned.
        """
        data_df["retention_true_mu"] = (
            -data_df["cohort_age"] / (data_df["age"] + 1)
            + 0.8 * np.cos(2 * np.pi * data_df["period"].dt.dayofyear / 365)
            + 0.5 * np.sin(2 * 3 * np.pi * data_df["period"].dt.dayofyear / 365)
            - 0.5 * np.log1p(data_df["age"])
            + 1.0
        )
        data_df["retention_true"] = expit(data_df["retention_true_mu"])
        return data_df

    def _generate_user_history(self, data_df: pd.DataFrame) -> pd.DataFrame:
        """Add ``n_active_users`` as Binomial draws from the true retention.

        The column is added to ``data_df`` in place; the same frame is returned.
        """
        data_df["n_active_users"] = self.rng.binomial(
            n=data_df["n_users"], p=data_df["retention_true"]
        )
        # In the acquisition period every user of the cohort counts as active.
        data_df["n_active_users"] = np.where(
            data_df["cohort_age"] == 0, data_df["n_users"], data_df["n_active_users"]
        )
        return data_df

    def run(
        self,
    ) -> pd.DataFrame:
        """Generate the full synthetic data set, one row per (cohort, period)."""
        return (
            self._generate_dataset_base()
            .pipe(self._generate_retention_rates)
            .pipe(self._generate_user_history)
        )
Let’s generate the data for \(48\) cohorts.
# Deterministic seed: the sum of the code points of the word "retention".
seed: int = sum(ord(char) for char in "retention")
rng: np.random.Generator = np.random.default_rng(seed=seed)

start_cohort: str = "2020-01-01"
n_cohorts: int = 48

cohort_generator = CohortDataGenerator(
    rng=rng,
    start_cohort=start_cohort,
    n_cohorts=n_cohorts,
)
data_df = cohort_generator.run()

# Observed retention rate: share of the cohort that was active in the period.
data_df["retention"] = data_df["n_active_users"].div(data_df["n_users"])

data_df.head()
 | cohort | n_users | period | age | cohort_age | retention_true_mu | retention_true | n_active_users | retention |
---|---|---|---|---|---|---|---|---|---|
0 | 2020-01-01 | 150 | 2020-01-01 | 1430 | 0 | -1.807373 | 0.140956 | 150 | 1.000000 |
1 | 2020-01-01 | 150 | 2020-02-01 | 1430 | 31 | -1.474736 | 0.186224 | 25 | 0.166667 |
2 | 2020-01-01 | 150 | 2020-03-01 | 1430 | 60 | -2.281286 | 0.092685 | 13 | 0.086667 |
3 | 2020-01-01 | 150 | 2020-04-01 | 1430 | 91 | -3.206610 | 0.038918 | 6 | 0.040000 |
4 | 2020-01-01 | 150 | 2020-05-01 | 1430 | 121 | -3.112983 | 0.042575 | 2 | 0.013333 |
Now we do a train test split to evaluate our model out-of-sample predictions.
# Periods up to and including the split date are used for training.
period_train_test_split = "2022-11-01"

train_data_df = data_df[data_df["period"] <= period_train_test_split]
test_data_df = data_df[data_df["period"] > period_train_test_split]

# Only evaluate on cohorts that also appear in the training data.
train_cohorts = train_data_df["cohort"].unique()
test_data_df = test_data_df[test_data_df["cohort"].isin(train_cohorts)]
EDA
Now, let’s have a look at the data. We restrict ourselves to the training data as we really want to use the test data as a hold-out-set. First, we plot the number of users per cohort and the number of users that returned back per cohort.
fig, ax = plt.subplots(
    nrows=2, ncols=1, sharex=True, sharey=False, figsize=(12, 7), layout="constrained"
)

# Top panel: acquired users per cohort (one row per cohort after de-duplication).
cohort_sizes_df = (
    train_data_df[["cohort", "n_users"]].drop_duplicates().set_index("cohort")
)
cohort_sizes_df.plot(color="C0", marker="o", ax=ax[0])
ax[0].set(title="Number of users per cohort", xlabel="cohort", ylabel="number of users")

# Bottom panel: one line of returning users per cohort, acquisition period
# excluded. The last training cohort is skipped via the [:-1] slice.
returning_users_df = train_data_df.query("cohort_age > 0")
for cohort in train_data_df["cohort"].unique()[:-1]:
    cohort_rows = returning_users_df[returning_users_df["cohort"] == cohort]
    cohort_rows.set_index("period")["n_active_users"].plot(
        color="black", alpha=0.5, ax=ax[1]
    )
ax[1].set(
    title="Number of active users per cohort (log scale)",
    xlabel="period",
    ylabel="number of active users",
    yscale="log",
);
- Number of users: We see a mild positive trend with some peaks which do not seem driven by seasonality.
- Number of returning users: We see a clear seasonality component with a peak in the winter months.
Now we plot the retention rates per cohort.
fig, ax = plt.subplots(figsize=(18, 7))


def fmt(y, _):
    """Format a colour-bar tick value as a whole-number percentage."""
    return f"{y :0.0%}"


# Cohort x period matrix of observed retention, acquisition period excluded.
retention_matrix = (
    train_data_df.query("cohort_age != 0")
    .assign(
        cohort=lambda df: df["cohort"].dt.strftime("%Y-%m"),
        period=lambda df: df["period"].dt.strftime("%Y-%m"),
    )
    .filter(["cohort", "period", "retention"])
    .pivot(index="cohort", columns="period", values="retention")
)

sns.heatmap(
    data=retention_matrix,
    cmap="viridis_r",
    linewidths=0.2,
    linecolor="black",
    annot=True,
    fmt="0.0%",
    cbar_kws={"format": mtick.FuncFormatter(fmt)},
    ax=ax,
)
ax.set_title("Retention by Cohort and Period");
fig, ax = plt.subplots(figsize=(12, 7))

# Cohorts are labelled as "YYYY-MM" strings for the legend.
cohort_retention_df = train_data_df.query("cohort_age > 0").assign(
    cohort=lambda df: df["cohort"].dt.strftime("%Y-%m")
)
sns.lineplot(
    data=cohort_retention_df,
    x="period",
    y="retention",
    hue="cohort",
    palette="viridis_r",
    alpha=0.8,
    ax=ax,
)
ax.legend(title="cohort", loc="center left", bbox_to_anchor=(1, 0.5), fontsize=7.5)
ax.set(title="Retention by Cohort and Period");
- It seems that for a given period, the retention rates of the new cohorts are higher than the retention rates of the older cohorts. This is a clear indication that the retention rate is a function of the absolute cohort age.
- We also see a clear seasonality component in the retention rates.
- For a given cohort, seasonality peaks are decreasing as a function of time (period).
Let’s have a look at this last point in more detail. We plot the retention rates per month:
# One facet per calendar month of the observation period.
monthly_retention_df = train_data_df.query("cohort_age > 0").assign(
    cohort=lambda df: df["cohort"].dt.strftime("%Y-%m"),
    month=lambda df: df["period"].dt.strftime("%b"),
)
g = sns.relplot(
    data=monthly_retention_df,
    x="period",
    y="retention",
    hue="cohort",
    col="month",
    col_wrap=3,
    kind="line",
    marker="o",
    palette="viridis_r",
    alpha=0.8,
    height=2,
    aspect=3,
    facet_kws={"sharex": True, "sharey": True},
)
Finally, we plot the retention rates for each period (observation time) as a function of the cohort age (i.e. how long the cohort has existed at that period):
fig, ax = plt.subplots(figsize=(12, 7))

# Periods are labelled as "YYYY-MM" strings; each line is one observation period.
period_retention_df = train_data_df.query("cohort_age > 0").assign(
    period=lambda df: df["period"].dt.strftime("%Y-%m")
)
sns.lineplot(
    data=period_retention_df,
    x="cohort_age",
    y="retention",
    hue="period",
    palette="viridis_r",
    alpha=0.8,
    ax=ax,
)
ax.legend(title="period", loc="center left", bbox_to_anchor=(1, 0.5), fontsize=7.5)
ax.set(title="Retention by Cohort and Period");
Most of them have a negative trend. This is just representing the vertical columns from the retention matrix above.
One of the main takeaways from this EDA is that the retention rate is a function of (after controlling for seasonality) the absolute cohort age and the period, with a potential interaction effect.
Model
We now want to model the retention rates based on the EDA above. We use a generalized linear model with binomial likelihood and a logit link function. Concretely,
\[\begin{align*} N_{\text{active}} \sim &\text{Binomial}(N_{\text{total}}, p) \\ \textrm{logit}(p) = (& \text{intercept} \\ & + \beta_{\text{cohort age}} \text{cohort age} \\ & + \beta_{\text{age}} \text{age} \\ & + \beta_{\text{cohort age} \times \text{age}} \text{cohort age} \times \text{age} \\ & + \beta_{\text{seasonality}} \text{seasonality} ) \end{align*}\]
We also choose normal priors which are not very informative.
Remarks
- For the seasonality we use dummy monthly indicators. In addition, we use a `ZeroSumNormal` distribution to model these terms as we are interested in the relative effect (note that we have an intercept in the model to account for the baseline).
- For the numerical variables (`cohort age` and `age`) we scale the data.
Before we jump into the data transformations and model fitting, let’s have a look at the data structure:
# Preview the first rows of the training split.
train_data_df.head()
 | cohort | n_users | period | age | cohort_age | retention_true_mu | retention_true | n_active_users | retention |
---|---|---|---|---|---|---|---|---|---|
0 | 2020-01-01 | 150 | 2020-01-01 | 1430 | 0 | -1.807373 | 0.140956 | 150 | 1.000000 |
1 | 2020-01-01 | 150 | 2020-02-01 | 1430 | 31 | -1.474736 | 0.186224 | 25 | 0.166667 |
2 | 2020-01-01 | 150 | 2020-03-01 | 1430 | 60 | -2.281286 | 0.092685 | 13 | 0.086667 |
3 | 2020-01-01 | 150 | 2020-04-01 | 1430 | 91 | -3.206610 | 0.038918 | 6 | 0.040000 |
4 | 2020-01-01 | 150 | 2020-05-01 | 1430 | 121 | -3.112983 | 0.042575 | 2 | 0.013333 |
Data Transformations
# Drop the acquisition period (cohort_age == 0), in which every user of the
# cohort is counted as active.
train_data_red_df = train_data_df.query("cohort_age > 0").reset_index(drop=True)

train_obs_idx = train_data_red_df.index.to_numpy()
train_n_users = train_data_red_df["n_users"].to_numpy()
train_n_active_users = train_data_red_df["n_active_users"].to_numpy()
train_retention = train_data_red_df["retention"].to_numpy()

# Continuous features: scaled to unit variance without centring (with_mean=False).
train_age = train_data_red_df["age"].to_numpy()
train_age_scaler = StandardScaler(with_mean=False)
train_age_scaled = train_age_scaler.fit_transform(train_age[:, np.newaxis]).ravel()

train_cohort_age = train_data_red_df["cohort_age"].to_numpy()
train_cohort_age_scaler = StandardScaler(with_mean=False)
train_cohort_age_scaled = train_cohort_age_scaler.fit_transform(
    train_cohort_age[:, np.newaxis]
).ravel()

# Categorical features: label-encoded to integer indices.
train_cohort = train_data_red_df["cohort"].to_numpy()
train_cohort_encoder = LabelEncoder()
train_cohort_idx = train_cohort_encoder.fit_transform(train_cohort)

train_period = train_data_red_df["period"].to_numpy()
train_period_encoder = LabelEncoder()
train_period_idx = train_period_encoder.fit_transform(train_period)

train_period_month = train_data_red_df["period"].dt.month.to_numpy()
train_period_month_encoder = LabelEncoder()
train_period_month_idx = train_period_month_encoder.fit_transform(train_period_month)

# Model coordinates: one seasonality coefficient per observed calendar month.
coords: dict[str, npt.NDArray] = {
    "period_month": train_period_month_encoder.classes_,
}
Remark: Note that we know all features at prediction time!
Model Specification
The following PyMC model implements the model specification above. We use `pm.MutableData` wrappers so that we can update the model with new data when generating predictions.
with pm.Model(coords=coords) as model:
    # --- Data ---
    # Mutable containers, so the same model can later be fed the test data.
    model.add_coord(name="obs", values=train_obs_idx, mutable=True)

    age_scaled = pm.MutableData(name="age_scaled", value=train_age_scaled, dims="obs")
    cohort_age_scaled = pm.MutableData(
        name="cohort_age_scaled", value=train_cohort_age_scaled, dims="obs"
    )
    period_month_idx = pm.MutableData(
        name="period_month_idx", value=train_period_month_idx, dims="obs"
    )
    n_users = pm.MutableData(name="n_users", value=train_n_users, dims="obs")
    n_active_users = pm.MutableData(
        name="n_active_users", value=train_n_active_users, dims="obs"
    )

    # --- Priors ---
    intercept = pm.Normal(name="intercept", mu=0, sigma=1)
    b_age_scaled = pm.Normal(name="b_age_scaled", mu=0, sigma=1)
    b_cohort_age_scaled = pm.Normal(name="b_cohort_age_scaled", mu=0, sigma=1)
    b_period_month = pm.ZeroSumNormal(
        name="b_period_month", sigma=1, dims="period_month"
    )
    b_age_cohort_age_interaction = pm.Normal(
        name="b_age_cohort_age_interaction", mu=0, sigma=1
    )

    # --- Parametrization ---
    # Linear predictor on the logit scale, assembled term by term.
    age_term = b_age_scaled * age_scaled
    cohort_age_term = b_cohort_age_scaled * cohort_age_scaled
    interaction_term = b_age_cohort_age_interaction * age_scaled * cohort_age_scaled
    seasonality_term = b_period_month[period_month_idx]

    mu = pm.Deterministic(
        name="mu",
        var=intercept
        + age_term
        + cohort_age_term
        + interaction_term
        + seasonality_term,
        dims="obs",
    )
    p = pm.Deterministic(name="p", var=pm.math.invlogit(mu), dims="obs")

    # --- Likelihood ---
    pm.Binomial(
        name="likelihood",
        n=n_users,
        p=p,
        observed=n_active_users,
        dims="obs",
    )

pm.model_to_graphviz(model=model)
Model Fitting
with model:
    # Seed both sampling steps with the notebook's generator so that the
    # posterior and the posterior predictive draws are reproducible.
    idata = pm.sampling_jax.sample_numpyro_nuts(
        target_accept=0.8, draws=2_000, chains=4, random_seed=rng
    )
    posterior_predictive = pm.sample_posterior_predictive(
        trace=idata, random_seed=rng
    )
Model Diagnostics
The model samples nicely and diagnostics look good:
# Posterior summary table; the "~" prefix excludes the per-observation
# deterministics `mu` and `p` from the output.
az.summary(data=idata, var_names=["~mu", "~p"])
 | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
---|---|---|---|---|---|---|---|---|---|
intercept | -1.856 | 0.023 | -1.901 | -1.814 | 0.0 | 0.0 | 5196.0 | 5384.0 | 1.0 |
b_age_scaled | -0.141 | 0.007 | -0.154 | -0.128 | 0.0 | 0.0 | 5786.0 | 5706.0 | 1.0 |
b_cohort_age_scaled | -0.568 | 0.026 | -0.615 | -0.517 | 0.0 | 0.0 | 5691.0 | 5961.0 | 1.0 |
b_age_cohort_age_interaction | 0.075 | 0.006 | 0.064 | 0.086 | 0.0 | 0.0 | 5258.0 | 5233.0 | 1.0 |
b_period_month[1] | 0.806 | 0.012 | 0.783 | 0.828 | 0.0 | 0.0 | 13899.0 | 5449.0 | 1.0 |
b_period_month[2] | 1.178 | 0.011 | 1.157 | 1.198 | 0.0 | 0.0 | 13509.0 | 5804.0 | 1.0 |
b_period_month[3] | 0.406 | 0.013 | 0.382 | 0.431 | 0.0 | 0.0 | 12560.0 | 6055.0 | 1.0 |
b_period_month[4] | -0.507 | 0.018 | -0.541 | -0.473 | 0.0 | 0.0 | 11474.0 | 5408.0 | 1.0 |
b_period_month[5] | -0.461 | 0.018 | -0.494 | -0.428 | 0.0 | 0.0 | 11166.0 | 5476.0 | 1.0 |
b_period_month[6] | -0.182 | 0.014 | -0.207 | -0.156 | 0.0 | 0.0 | 13638.0 | 5294.0 | 1.0 |
b_period_month[7] | -0.808 | 0.016 | -0.837 | -0.776 | 0.0 | 0.0 | 12110.0 | 5431.0 | 1.0 |
b_period_month[8] | -1.182 | 0.019 | -1.217 | -1.148 | 0.0 | 0.0 | 11723.0 | 5748.0 | 1.0 |
b_period_month[9] | -0.351 | 0.013 | -0.376 | -0.327 | 0.0 | 0.0 | 11346.0 | 5962.0 | 1.0 |
b_period_month[10] | 0.511 | 0.009 | 0.493 | 0.529 | 0.0 | 0.0 | 12123.0 | 6523.0 | 1.0 |
b_period_month[11] | 0.394 | 0.010 | 0.376 | 0.414 | 0.0 | 0.0 | 10509.0 | 5545.0 | 1.0 |
b_period_month[12] | 0.195 | 0.017 | 0.163 | 0.228 | 0.0 | 0.0 | 8672.0 | 6512.0 | 1.0 |
# Trace plot of the model parameters only (`mu` and `p` excluded).
parameter_var_names = ["~mu", "~p"]
trace_backend_kwargs = dict(figsize=(12, 10), layout="constrained")

_ = az.plot_trace(
    data=idata,
    var_names=parameter_var_names,
    compact=True,
    backend_kwargs=trace_backend_kwargs,
)
plt.gcf().suptitle("Model Trace", fontsize=16);
# Forest plot of the model parameters (`mu` and `p` excluded), with a
# reference line at zero.
ax, *_ = az.plot_forest(data=idata, var_names=["~mu", "~p"], combined=True)
ax.axvline(x=0, color="black", linestyle="--")
# The plot shows coefficient posteriors on the logit scale, not posterior
# predictive retention rates, so it is labelled accordingly.
ax.set(
    title="Model Parameters - Posterior Distributions",
    xlabel="parameter value (logit scale)",
);
It is nice to see how the seasonality coefficients resemble seasonality patterns from the EDA.
We can now look into the posterior predictive check:
# Cumulative posterior predictive check, with the x-axis on a log scale.
ax = az.plot_ppc(data=posterior_predictive, kind="cumulative", observed_rug=True)
ppc_axis_settings = {
    "title": "Posterior Predictive Check",
    "xscale": "log",
    "xlabel": "likelihood (n_active_users) - log scale",
}
ax.set(**ppc_axis_settings);
Retention Rate In-Sample Predictions
Now we can use the model to predict the retention rates for the training data. We plot the observed retention rates and the predicted (posterior mean) retention rates.
# Turn posterior predictive counts into retention rates by dividing by the
# cohort size; the two leading axes of length one broadcast against the
# leading dimensions of the posterior predictive draws.
train_posterior_retention = (
    posterior_predictive.posterior_predictive / train_n_users[None, None, :]
)
train_posterior_retention_mean = az.extract(
    data=train_posterior_retention, var_names=["likelihood"]
).mean("sample")

# Observed vs. posterior mean retention; points on the diagonal are perfect fits.
scatter_df = train_data_red_df.assign(
    posterior_retention_mean=train_posterior_retention_mean
)
fig, ax = plt.subplots(figsize=(10, 9))
sns.scatterplot(
    data=scatter_df,
    x="retention",
    y="posterior_retention_mean",
    hue="age",
    size="n_users",
    palette="viridis_r",
    ax=ax,
)
ax.axline(xy1=(0, 0), slope=1, color="black", linestyle="--", label="diagonal")
ax.legend()
ax.set(title="Posterior Predictive - Retention Mean");
The results look quite good 🚀 !
Now, we can deep dive into specific cohorts to see the predictions and uncertainty of the estimates.
# HDI of the in-sample posterior retention, one (lower, upper) pair per observation.
train_retention_hdi = az.hdi(ary=train_posterior_retention)["likelihood"]


def plot_train_retention_hdi_cohort(cohort_index: int, ax: plt.Axes) -> plt.Axes:
    """Plot observed training retention and its HDI band for one cohort.

    Parameters
    ----------
    cohort_index
        Label-encoded cohort index, as produced by ``train_cohort_encoder``.
    ax
        Axes to draw on.

    Returns
    -------
    plt.Axes
        The axes that were passed in, with the band and the line added.
    """
    mask = train_cohort_idx == cohort_index
    # `train_period` holds one period per observation, so it is filtered with
    # the boolean mask directly. Indexing it with the label-encoded period
    # indices only yields the right dates when the first rows of the training
    # frame happen to list every period in sorted order.
    cohort_periods = train_period[mask]
    cohort_hdi = train_retention_hdi[mask, :]

    ax.fill_between(
        x=cohort_periods,
        y1=cohort_hdi[:, 0],
        y2=cohort_hdi[:, 1],
        alpha=0.3,
        color="C0",
        label="94% HDI (train)",
    )
    sns.lineplot(
        x=cohort_periods,
        y=train_retention[mask],
        color="C0",
        marker="o",
        label="observed retention (train)",
        ax=ax,
    )
    cohort_name = (
        pd.to_datetime(train_cohort_encoder.classes_[cohort_index]).date().isoformat()
    )
    ax.legend(loc="upper left")
    ax.set(title=f"Retention HDI - Cohort {cohort_name}")
    return ax
# Selected cohorts, laid out on a two-column grid with shared axes.
cohort_index_to_plot = [0, 1, 5, 10, 15, 20, 25, 30]
n_grid_rows = int(np.ceil(len(cohort_index_to_plot) / 2))

fig, axes = plt.subplots(
    nrows=n_grid_rows,
    ncols=2,
    figsize=(15, 10),
    sharex=True,
    sharey=True,
    layout="constrained",
)

for cohort_index, ax in zip(cohort_index_to_plot, axes.ravel()):
    plot_train_retention_hdi_cohort(cohort_index=cohort_index, ax=ax)
Here we also see that the model captures the development of the retention rate over time. In addition, we see that older cohorts have wider uncertainty intervals, as their number of base users is lower than that of the newer cohorts.
Predictions
Now we transform the test data to the same format as the training data and use the model to predict the retention rates. Note that we are using the scalers and encoders from the training data.
Data Transformations
# Keep post-acquisition observations of cohorts that appear in the training data.
test_data_red_df = test_data_df.query("cohort_age > 0")
test_data_red_df = test_data_red_df[
    test_data_red_df["cohort"].isin(train_data_red_df["cohort"].unique())
].reset_index(drop=True)
test_obs_idx = test_data_red_df.index.to_numpy()
test_n_users = test_data_red_df["n_users"].to_numpy()
test_n_active_users = test_data_red_df["n_active_users"].to_numpy()
test_retention = test_data_red_df["retention"].to_numpy()

# Continuous features: reuse the scalers fitted on the training data.
test_age = test_data_red_df["age"].to_numpy()
test_age_scaled = train_age_scaler.transform(test_age.reshape(-1, 1)).flatten()
test_cohort_age = test_data_red_df["cohort_age"].to_numpy()
test_cohort_age_scaled = train_cohort_age_scaler.transform(
    test_cohort_age.reshape(-1, 1)
).flatten()

# Categorical features: reuse the encoders fitted on the training data.
test_cohort = test_data_red_df["cohort"].to_numpy()
test_cohort_idx = train_cohort_encoder.transform(test_cohort).flatten()
test_period_month = test_data_red_df["period"].dt.month.to_numpy()
# `transform`, not `fit_transform`: refitting on the test months would
# overwrite the training encoder and could map a month to a different index
# than the one its `b_period_month` coefficient was learned under.
test_period_month_idx = train_period_month_encoder.transform(
    test_period_month
).flatten()
Out-of-Sample Posterior Predictions
We now replace the model data with the test data and generate posterior predictions.
# Test-set values for every mutable data container of the model.
test_model_data = {
    "age_scaled": test_age_scaled,
    "cohort_age_scaled": test_cohort_age_scaled,
    "period_month_idx": test_period_month_idx,
    "n_users": test_n_users,
    "n_active_users": np.ones_like(
        test_n_active_users
    ),  # Dummy data to make coords work! We are not using this at prediction time!
}
test_coords = {"obs": test_obs_idx}

with model:
    pm.set_data(new_data=test_model_data, coords=test_coords)
    test_posterior_predictive = pm.sample_posterior_predictive(
        trace=idata,
        var_names=["likelihood", "p", "mu"],
        idata_kwargs={"coords": {"obs": test_obs_idx}},
    )
    # Attach the out-of-sample predictions to the fitted inference data.
    idata.extend(test_posterior_predictive)
Retention Rate Out-of-Sample Predictions
Finally we compute the posterior retention rate distributions for the test data and visualize the results.
# Out-of-sample posterior retention: predicted counts divided by cohort size.
test_posterior_retention = (
    idata.posterior_predictive["likelihood"] / test_n_users[None, None, :]
)
test_retention_hdi = az.hdi(ary=test_posterior_retention)["likelihood"]


def plot_test_retention_hdi_cohort(cohort_index: int, ax: plt.Axes) -> plt.Axes:
    """Plot observed test retention and its HDI band for one cohort.

    ``cohort_index`` is the label-encoded cohort index from
    ``train_cohort_encoder``; the band and the line are drawn on ``ax``,
    which is returned.
    """
    mask = test_cohort_idx == cohort_index
    cohort_label = train_cohort_encoder.classes_[cohort_index]
    test_period_range = test_data_red_df.query(f"cohort == '{cohort_label}'")[
        "period"
    ]
    cohort_hdi = test_retention_hdi[mask, :]

    ax.fill_between(
        x=test_period_range,
        y1=cohort_hdi[:, 0],
        y2=cohort_hdi[:, 1],
        alpha=0.3,
        color="C1",
        label="94% HDI (test)",
    )
    sns.lineplot(
        x=test_period_range,
        y=test_retention[mask],
        color="C1",
        marker="o",
        label="observed retention (test)",
        ax=ax,
    )
    return ax
# One row per selected cohort, showing the train and test periods together.
cohort_index_to_plot = [0, 1, 5, 10, 15, 20, 25, 30]
n_panels = len(cohort_index_to_plot)

fig, axes = plt.subplots(
    nrows=n_panels,
    ncols=1,
    figsize=(12, 15),
    sharex=True,
    sharey=True,
    layout="constrained",
)

split_date = pd.to_datetime(period_train_test_split)
for cohort_index, ax in zip(cohort_index_to_plot, axes.ravel()):
    plot_train_retention_hdi_cohort(cohort_index=cohort_index, ax=ax)
    plot_test_retention_hdi_cohort(cohort_index=cohort_index, ax=ax)
    # Mark where the training data ends.
    ax.axvline(
        x=split_date,
        color="black",
        linestyle="--",
        label="train/test split",
    )
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

fig.suptitle("Retention Predictions", y=1.03, fontsize=16);
The out-of-sample predictions look quite good as well 🙌 !
For real data sets we expect this model to be a baseline for further improvements. For example, we could add more features (e.g. country, user segment, etc.) or we could use a more complex model (e.g. a hierarchical model).