12 min read

Cohort Revenue Retention Analysis with Flax and NumPyro

In this notebook we present an alternative implementation of the cohort-revenue-retention model presented in the blog post Cohort Revenue & Retention Analysis: A Bayesian Approach where we show how to replace the BART retention component with a general neural network implemented with Flax. This allows faster inference, as we can use NumPyro’s NUTS sampler or any of the stochastic variational inference (SVI) algorithms available. We could even use a wider family of samplers using the newly released package Bayeux or the great BlackJax (see for example, the MLP Classifier Example). We use the same simulated dataset to be able to compare the approaches. Overall, the retention and revenue in and out-of sample predictions, as well as the credible intervals are very similar to the ones obtained with the BART model.

Remark: On the other hand, we loose the PDP and ICE plots, which are useful to understand the influence of regressors on the target variable. We could of course trying implementing it by hand, but it would not be straightforward to make it fast, at least in my view (please let me know if you have any ideas on how to do it.)

Prepare Notebook

import arviz as az
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
import seaborn as sns
from flax import linen as nn
from jax import random
from numpyro.contrib.module import random_flax_module
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.util import Predictive
from scipy.special import logit
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import (
    LabelEncoder,
    MaxAbsScaler,
    OneHotEncoder,
    StandardScaler,
)

numpyro.set_host_device_count(n=4)

rng_key = random.PRNGKey(seed=42)

az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
seed: int = sum(map(ord, "retention"))
rng: np.random.Generator = np.random.default_rng(seed=seed)

Read Data

We start by reading the data from previous posts (see here for the code to generate the data).

data_df = pd.read_csv(
    "https://raw.githubusercontent.com/juanitorduz/website_projects/master/data/retention_data.csv",
    parse_dates=["cohort", "period"],
)

data_df.head()
cohort n_users period age cohort_age retention_true_mu retention_true n_active_users revenue retention
0 2020-01-01 150 2020-01-01 1430 0 -1.807373 0.140956 150 14019.256906 1.000000
1 2020-01-01 150 2020-02-01 1430 31 -1.474736 0.186224 25 1886.501237 0.166667
2 2020-01-01 150 2020-03-01 1430 60 -2.281286 0.092685 13 1098.136314 0.086667
3 2020-01-01 150 2020-04-01 1430 91 -3.206610 0.038918 6 477.852458 0.040000
4 2020-01-01 150 2020-05-01 1430 121 -3.112983 0.042575 2 214.667937 0.013333

Data Preprocessing

We make a data train-test split as in the previous post.

period_train_test_split = "2022-11-01"

train_data_df = data_df.query("period <= @period_train_test_split")
test_data_df = data_df.query("period > @period_train_test_split")
test_data_df = test_data_df[
    test_data_df["cohort"].isin(train_data_df["cohort"].unique())
]

EDA

For a detailed EDA of the data, please refer to the previous posts (A Simple Cohort Retention Analysis in PyMC, Cohort Retention Analysis with BART and Cohort Revenue & Retention Analysis: A Bayesian Approach). We assume you are familiar with the dataset.

Let’s recall how the retention matrix looks like:

fig, ax = plt.subplots(figsize=(17, 9))

(
    train_data_df.assign(
        cohort=lambda df: df["cohort"].dt.strftime("%Y-%m"),
        period=lambda df: df["period"].dt.strftime("%Y-%m"),
    )
    .query("cohort_age != 0")
    .filter(["cohort", "period", "retention"])
    .pivot_table(index="cohort", columns="period", values="retention")
    .pipe(
        (sns.heatmap, "data"),
        cmap="viridis_r",
        linewidths=0.2,
        linecolor="black",
        annot=True,
        fmt="0.0%",
        cbar_kws={"format": mtick.FuncFormatter(func=lambda y, _: f"{y :0.0%}")},
        ax=ax,
    )
)

ax.set_title("Retention by Cohort and Period")

Similarly we can plot the revenue matrix:

fig, ax = plt.subplots(figsize=(17, 9))


(
    train_data_df.assign(
        cohort=lambda df: df["cohort"].dt.strftime("%Y-%m"),
        period=lambda df: df["period"].dt.strftime("%Y-%m"),
    )
    .query("cohort_age != 0")
    .filter(["cohort", "period", "revenue"])
    .pivot_table(index="cohort", columns="period", values="revenue")
    .pipe(
        (sns.heatmap, "data"),
        cmap="viridis_r",
        linewidths=0.2,
        linecolor="black",
        annot=True,
        annot_kws={"fontsize": 6},
        fmt="0.0f",
        cbar_kws={"format": mtick.FuncFormatter(lambda y, _: f"{y :0.0f}")},
        ax=ax,
    )
)

ax.set_title("Revenue by Cohort and Period")

Our objective is to model both components as the same time as we expect the revenue to depend on the retention levels.

Model

Motivated by the analysis above we suggest the following retention-revenue model.

\[\begin{align*} \text{Revenue} & \sim \text{Gamma}(N_{\text{active}}, \lambda) \\ \log(\lambda) = (& \text{intercept} \\ & + \beta_{\text{cohort age}} \text{cohort age} \\ & + \beta_{\text{age}} \text{age} \\ & + \beta_{\text{cohort age} \times \text{age}} \text{cohort age} \times \text{age} ) \\ N_{\text{active}} & \sim \text{Binomial}(N_{\text{total}}, p) \\ \textrm{logit}(p) & = \text{NN}(\text{cohort age}, \text{age}, \text{month}) \end{align*}\]

where \(\text{NN}\) is a neural network implemented with Flax. We use a simple architecture with one hidden layers with \(4\) units each and sigmoid activation functions. For this simple case this architecture is enough. You can of course try more complex models.

The magic for this approach lies in the NumPyro module numpyro/contrib/module.py where we have very useful functions to integrate Flax models with NumPyro. In particular, we will use the random_flax_module which allow us to set priors on the layers weights and biases. You can find a simple example of this approach in the blog post Flax and NumPyro Toy Example.

Data Transformations

We do similar transformations as in the previous posts.

eps = np.finfo(float).eps
train_data_red_df = train_data_df.query("cohort_age > 0").reset_index(drop=True)
train_obs_idx = train_data_red_df.index.to_numpy()
train_n_users = train_data_red_df["n_users"].to_numpy()
train_n_active_users = train_data_red_df["n_active_users"].to_numpy()
train_retention = train_data_red_df["retention"].to_numpy()
train_retention_logit = logit(train_retention + eps)
train_data_red_df["month"] = train_data_red_df["period"].dt.strftime("%m").astype(int)
train_data_red_df["cohort_month"] = (
    train_data_red_df["cohort"].dt.strftime("%m").astype(int)
)
train_data_red_df["period_month"] = (
    train_data_red_df["period"].dt.strftime("%m").astype(int)
)
train_revenue = train_data_red_df["revenue"].to_numpy() + eps
train_revenue_per_user = train_revenue / (train_n_active_users + eps)

train_cohort = train_data_red_df["cohort"].to_numpy()
train_cohort_encoder = LabelEncoder()
train_cohort_idx = train_cohort_encoder.fit_transform(train_cohort).flatten()
train_period = train_data_red_df["period"].to_numpy()
train_period_encoder = LabelEncoder()
train_period_idx = train_period_encoder.fit_transform(train_period).flatten()

features: list[str] = ["age", "cohort_age", "month"]
x_train = train_data_red_df[features]

train_age = train_data_red_df["age"].to_numpy()
train_age_scaler = MaxAbsScaler()
train_age_scaled = train_age_scaler.fit_transform(train_age.reshape(-1, 1)).flatten()
train_cohort_age = train_data_red_df["cohort_age"].to_numpy()
train_cohort_age_scaler = MaxAbsScaler()
train_cohort_age_scaled = train_cohort_age_scaler.fit_transform(
    train_cohort_age.reshape(-1, 1)
).flatten()

For the variables entering into the model we need to convert them to jnp.array objects.

train_n_users = jnp.array(train_n_users)
train_n_active_users = jnp.array(train_n_active_users)
train_revenue = jnp.array(train_revenue)

Moreover, for the design matrix feeding the neural network we also need to scale the features (Note this was not necessary with the BART model). We also one-hot-encode the month variable.

numerical_features = ["age", "cohort_age"]
categorical_features = ["month"]

numerical_transformer = Pipeline(steps=[("scaler", StandardScaler())])
categorical_features_transformer = Pipeline(
    steps=[("onehot", OneHotEncoder(drop="first", sparse_output=False))]
)

preprocessor = ColumnTransformer(
    transformers=[
        ("num", numerical_transformer, numerical_features),
        ("cat", categorical_features_transformer, categorical_features),
    ]
).set_output(transform="pandas")

preprocessor.fit(x_train)
x_train_preprocessed = preprocessor.transform(x_train)

x_train_preprocessed.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 595 entries, 0 to 594
Data columns (total 13 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   num__age         595 non-null    float64
 1   num__cohort_age  595 non-null    float64
 2   cat__month_2     595 non-null    float64
 3   cat__month_3     595 non-null    float64
 4   cat__month_4     595 non-null    float64
 5   cat__month_5     595 non-null    float64
 6   cat__month_6     595 non-null    float64
 7   cat__month_7     595 non-null    float64
 8   cat__month_8     595 non-null    float64
 9   cat__month_9     595 non-null    float64
 10  cat__month_10    595 non-null    float64
 11  cat__month_11    595 non-null    float64
 12  cat__month_12    595 non-null    float64
dtypes: float64(13)
memory usage: 60.6 KB

Finally, we convert the transformed design matrix to a jnp.array object.

x_train_preprocessed_array = jnp.array(x_train_preprocessed)

Model Specification

We are now ready to write the model. First we define the neural network architecture in Flax.

class RetentionMLP(nn.Module):
    layers: list[int]

    @nn.compact
    def __call__(self, x):
        for num_features in self.layers:
            x = nn.sigmoid(nn.Dense(features=num_features)(x))
        return x

Now we write the NumPyro model. For the weights and biases we use set Normal priors centered around zero with unit variance.

def model(x, age, cohort_age, n_users, revenue=None, n_active_users=None):
    eps = np.finfo(float).eps

    retention_nn = random_flax_module(
        "retention_nn",
        RetentionMLP(layers=[4, 1]),
        prior=dist.Laplace(loc=0, scale=1),
        input_shape=(x.shape[1],),
    )
    retention = numpyro.deterministic("retention", retention_nn(x).squeeze(-1))

    intercept = numpyro.sample("intercept", dist.Normal(loc=0, scale=1))
    b_age = numpyro.sample("b_age", dist.Normal(loc=0, scale=1))
    b_cohort_age = numpyro.sample("b_cohort_age", dist.Normal(loc=0, scale=1))
    b_interaction = numpyro.sample("b_interaction", dist.Normal(loc=0, scale=1))

    lam_log = numpyro.deterministic(
        "lam_log",
        intercept
        + b_age * age
        + b_cohort_age * cohort_age
        + b_interaction * age * cohort_age,
    )

    lam = numpyro.deterministic("lam", jnp.exp(lam_log))

    with numpyro.plate("data", len(x)):
        n_active_users_estimated = numpyro.sample(
            "n_active_users_estimated",
            dist.Binomial(total_count=n_users, probs=retention),
            obs=n_active_users,
        )

        numpyro.deterministic("retention_estimated", n_active_users_estimated / n_users)

        numpyro.sample(
            "revenue_estimated",
            dist.Gamma(concentration=n_active_users_estimated + eps, rate=lam),
            obs=revenue,
        )

We can visualize the model using structure.

numpyro.render_model(
    model=model,
    model_kwargs={
        "x": x_train_preprocessed_array,
        "age": train_age_scaled,
        "cohort_age": train_cohort_age_scaled,
        "n_users": train_n_users,
        "revenue": train_revenue,
        "n_active_users": train_n_active_users,
    },
    render_distributions=True,
    render_params=True,
)

Inference: SVI

We use stochastic variational inference (SVI) to fit the model. We also se a AutoNormal guide as a variational distribution.

guide = AutoNormal(model=model)
optimizer = numpyro.optim.Adam(step_size=0.02)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
n_samples = 15_000
rng_key, rng_subkey = random.split(key=rng_key)
svi_result = svi.run(
    rng_subkey,
    n_samples,
    x_train_preprocessed_array,
    train_age_scaled,
    train_cohort_age_scaled,
    train_n_users,
    train_revenue,
    train_n_active_users,
)

fig, ax = plt.subplots(figsize=(9, 6))
ax.plot(svi_result.losses)
ax.set(yscale="log")
ax.set_title("ELBO loss", fontsize=18, fontweight="bold")

We see a nicely decaying ELBO curve.

Next we sample from the posterior distribution.

params = svi_result.params

posterior_predictive = Predictive(
    model=model, guide=guide, params=params, num_samples=4 * 2_000
)
rng_key, rng_subkey = random.split(key=rng_key)
posterior_predictive_samples = posterior_predictive(
    rng_subkey,
    x_train_preprocessed_array,
    train_age_scaled,
    train_cohort_age_scaled,
    train_n_users,
)

We structure the samples as an az.InferenceData object.

idata_svi = az.from_dict(
    posterior_predictive={
        k: np.expand_dims(a=np.asarray(v), axis=0)
        for k, v in posterior_predictive_samples.items()
    },
    coords={"obs_idx": train_obs_idx},
    dims={
        "retention": ["obs_idx"],
        "n_active_users_estimated": ["obs_idx"],
        "retention_estimated": ["obs_idx"],
        "revenue_estimated": ["obs_idx"],
    },
)

In-Sample Predictions

We now look at the in-sample fit. First we start with retention:

fig, ax = plt.subplots(figsize=(9, 7))
sns.scatterplot(
    x=train_retention,
    y=idata_svi["posterior_predictive"]["retention_estimated"].mean(
        dim=["chain", "draw"]
    ),
    color="C0",
    label="Mean Predicted Retention",
    ax=ax,
)
az.plot_hdi(
    x=train_retention,
    y=idata_svi["posterior_predictive"]["retention_estimated"],
    hdi_prob=0.94,
    color="C0",
    fill_kwargs={"alpha": 0.2, "label": "$94\\%$ HDI"},
    ax=ax,
)
az.plot_hdi(
    x=train_retention,
    y=idata_svi["posterior_predictive"]["retention_estimated"],
    hdi_prob=0.5,
    color="C0",
    fill_kwargs={"alpha": 0.4, "label": "$50\\%$ HDI"},
    ax=ax,
)
ax.axline((0, 0), slope=1, color="black", linestyle="--", label="diagonal")
ax.legend(loc="upper left")
ax.set(xlabel="True Retention", ylabel="Predicted Retention")
ax.set_title(
    label="True vs Predicted Retention (Train)", fontsize=18, fontweight="bold"
)

Overall it looks very reasonable. There are some point the model does not properly catch, but the previous BART model had the same issue.

The revenue in-sample predictions are also very similar to the ones obtained with the BART model.

fig, ax = plt.subplots(figsize=(9, 7))
sns.scatterplot(
    x=train_revenue,
    y=idata_svi["posterior_predictive"]["revenue_estimated"].mean(
        dim=["chain", "draw"]
    ),
    color="C0",
    label="Mean Predicted revenue_estimated",
    ax=ax,
)
az.plot_hdi(
    x=train_revenue,
    y=idata_svi["posterior_predictive"]["revenue_estimated"],
    hdi_prob=0.94,
    color="C0",
    fill_kwargs={"alpha": 0.2, "label": "$94\\%$ HDI"},
    ax=ax,
)
az.plot_hdi(
    x=train_revenue,
    y=idata_svi["posterior_predictive"]["revenue_estimated"],
    hdi_prob=0.5,
    color="C0",
    fill_kwargs={"alpha": 0.4, "label": "$50\\%$ HDI"},
    ax=ax,
)
ax.axline((0, 0), slope=1, color="black", linestyle="--", label="diagonal")
ax.legend(loc="upper left")
ax.set(xlabel="True Revenue", ylabel="Predicted Revenue")
ax.set_title(label="True vs Predicted Revenue (Train)", fontsize=18, fontweight="bold")

As in the previous post, we select a couple of example cohorts to visualize the predictions:

  • Retention
train_retention_hdi = az.hdi(ary=idata_svi["posterior_predictive"])[
    "retention_estimated"
]


def plot_train_retention_hdi_cohort(cohort_index: int, ax: plt.Axes) -> plt.Axes:
    mask = train_cohort_idx == cohort_index

    ax.fill_between(
        x=train_period[train_period_idx[mask]],
        y1=train_retention_hdi[mask, :][:, 0],
        y2=train_retention_hdi[mask, :][:, 1],
        alpha=0.2,
        color="C3",
        label="94% HDI",
    )
    sns.lineplot(
        x=train_period[train_period_idx[mask]],
        y=idata_svi["posterior_predictive"]["retention_estimated"].mean(
            dim=["chain", "draw"]
        )[mask],
        marker="o",
        color="C3",
        label="predicted",
        ax=ax,
    )
    sns.lineplot(
        x=train_period[train_period_idx[mask]],
        y=train_retention[mask],
        color="C0",
        marker="o",
        label="observed",
        ax=ax,
    )
    cohort_name = (
        pd.to_datetime(train_cohort_encoder.classes_[cohort_index]).date().isoformat()
    )
    ax.legend(loc="upper left")
    ax.set(title=f"Retention - Cohort {cohort_name}")
    return ax


cohort_index_to_plot = [0, 1, 5, 10, 15, 20, 25, 30]

fig, axes = plt.subplots(
    nrows=np.ceil(len(cohort_index_to_plot) / 2).astype(int),
    ncols=2,
    figsize=(17, 11),
    sharex=True,
    sharey=True,
    layout="constrained",
)

for cohort_index, ax in zip(cohort_index_to_plot, axes.flatten(), strict=True):
    plot_train_retention_hdi_cohort(cohort_index=cohort_index, ax=ax)

fig.suptitle("In-Sample Retention", y=1.03, fontsize=20, fontweight="bold")
fig.autofmt_xdate()
  • Revenue
train_revenue_estimated_hdi = az.hdi(ary=idata_svi["posterior_predictive"])[
    "revenue_estimated"
]


def plot_train_revenue_hdi_cohort(cohort_index: int, ax: plt.Axes) -> plt.Axes:
    mask = train_cohort_idx == cohort_index

    ax.fill_between(
        x=train_period[train_period_idx[mask]],
        y1=train_revenue_estimated_hdi[mask, :][:, 0],
        y2=train_revenue_estimated_hdi[mask, :][:, 1],
        alpha=0.2,
        color="C3",
        label="94% HDI",
    )
    sns.lineplot(
        x=train_period[train_period_idx[mask]],
        y=idata_svi["posterior_predictive"]["revenue_estimated"].mean(
            dim=["chain", "draw"]
        )[mask],
        marker="o",
        color="C3",
        label="predicted",
        ax=ax,
    )
    sns.lineplot(
        x=train_period[train_period_idx[mask]],
        y=train_revenue[mask],
        color="C0",
        marker="o",
        label="observed",
        ax=ax,
    )
    cohort_name = (
        pd.to_datetime(train_cohort_encoder.classes_[cohort_index]).date().isoformat()
    )
    ax.legend(loc="upper left")
    ax.set(title=f"revenue_estimated - Cohort {cohort_name}")
    return ax


cohort_index_to_plot = [0, 1, 5, 10, 15, 20, 25, 30]

fig, axes = plt.subplots(
    nrows=np.ceil(len(cohort_index_to_plot) / 2).astype(int),
    ncols=2,
    figsize=(17, 11),
    sharex=True,
    sharey=False,
    layout="constrained",
)

for cohort_index, ax in zip(cohort_index_to_plot, axes.flatten(), strict=True):
    plot_train_revenue_hdi_cohort(cohort_index=cohort_index, ax=ax)

fig.suptitle("In-Sample Revenue", y=1.03, fontsize=20, fontweight="bold")
fig.autofmt_xdate()

The fit look very good and are very similar to the ones obtained with the BART model! Note that the HDI intervals are wider for smaller cohorts as expected.

Out-of-Sample Predictions

Now we prepare the data for the out-of-sample predictions.

test_data_red_df = test_data_df.query("cohort_age > 0")
test_data_red_df = test_data_red_df[
    test_data_red_df["cohort"].isin(train_data_red_df["cohort"].unique())
].reset_index(drop=True)
test_obs_idx = test_data_red_df.index.to_numpy()
test_n_users = test_data_red_df["n_users"].to_numpy()
test_n_active_users = test_data_red_df["n_active_users"].to_numpy()
test_retention = test_data_red_df["retention"].to_numpy()
test_revenue = test_data_red_df["revenue"].to_numpy()

test_cohort = test_data_red_df["cohort"].to_numpy()
test_cohort_idx = train_cohort_encoder.transform(test_cohort).flatten()

test_data_red_df["month"] = test_data_red_df["period"].dt.strftime("%m").astype(int)
test_data_red_df["cohort_month"] = (
    test_data_red_df["cohort"].dt.strftime("%m").astype(int)
)
test_data_red_df["period_month"] = (
    test_data_red_df["period"].dt.strftime("%m").astype(int)
)
x_test = test_data_red_df[features]

test_age = test_data_red_df["age"].to_numpy()
test_age_scaled = train_age_scaler.transform(test_age.reshape(-1, 1)).flatten()
test_cohort_age = test_data_red_df["cohort_age"].to_numpy()
test_cohort_age_scaled = train_cohort_age_scaler.transform(
    test_cohort_age.reshape(-1, 1)
).flatten()
test_n_users = jnp.array(test_n_users)
test_n_active_users = jnp.array(test_n_active_users)
test_revenue = jnp.array(test_revenue)

x_test_preprocessed = preprocessor.transform(x_test)
x_test_preprocessed_array = jnp.array(x_test_preprocessed)

Now we are ready to sample from the posterior distribution.

test_predictive = Predictive(
    model=model, guide=guide, params=params, num_samples=4 * 2_000
)
rng_key, rng_subkey = random.split(key=rng_key)
test_posterior_predictive_samples = test_predictive(
    rng_subkey,
    x_test_preprocessed_array,
    test_age_scaled,
    test_cohort_age_scaled,
    test_n_users,
)

Again, we structure the samples as an az.InferenceData object.

test_idata_svi = az.from_dict(
    posterior_predictive={
        k: np.expand_dims(a=np.asarray(v), axis=0)
        for k, v in test_posterior_predictive_samples.items()
    },
    coords={"obs_idx": test_obs_idx},
    dims={
        "retention": ["obs_idx"],
        "n_active_users_estimates": ["obs_idx"],
        "revenue_estimated": ["obs_idx"],
    },
)

Let’s now look at the actual against the predictions.

  • Retention
fig, ax = plt.subplots(figsize=(9, 7))
sns.scatterplot(
    x=test_retention,
    y=test_idata_svi["posterior_predictive"]["retention_estimated"].mean(
        dim=["chain", "draw"]
    ),
    color="C1",
    label="Mean Predicted Retention",
    ax=ax,
)
az.plot_hdi(
    x=test_retention,
    y=test_idata_svi["posterior_predictive"]["retention_estimated"],
    hdi_prob=0.94,
    color="C1",
    fill_kwargs={"alpha": 0.2, "label": "$94\\%$ HDI"},
    ax=ax,
)
az.plot_hdi(
    x=test_retention,
    y=test_idata_svi["posterior_predictive"]["retention_estimated"],
    hdi_prob=0.5,
    color="C1",
    fill_kwargs={"alpha": 0.4, "label": "$50\\%$ HDI"},
    ax=ax,
)
ax.axline((0, 0), slope=1, color="black", linestyle="--", label="diagonal")
ax.legend(loc="upper left")
ax.set(xlabel="True Retention", ylabel="Predicted Retention")
ax.set_title(label="True vs Predicted Retention (test)", fontsize=18, fontweight="bold")
  • Revenue
fig, ax = plt.subplots(figsize=(9, 7))
sns.scatterplot(
    x=test_revenue,
    y=test_idata_svi["posterior_predictive"]["revenue_estimated"].mean(
        dim=["chain", "draw"]
    ),
    color="C1",
    label="Mean Predicted revenue_estimated",
    ax=ax,
)
az.plot_hdi(
    x=test_revenue,
    y=test_idata_svi["posterior_predictive"]["revenue_estimated"],
    hdi_prob=0.94,
    color="C1",
    fill_kwargs={"alpha": 0.2, "label": "$94\\%$ HDI"},
    ax=ax,
)
az.plot_hdi(
    x=test_revenue,
    y=test_idata_svi["posterior_predictive"]["revenue_estimated"],
    hdi_prob=0.5,
    color="C1",
    fill_kwargs={"alpha": 0.4, "label": "$50\\%$ HDI"},
    ax=ax,
)
ax.axline((0, 0), slope=1, color="black", linestyle="--", label="diagonal")
ax.legend(loc="upper left")
ax.set(xlabel="True revenue_estimated", ylabel="Predicted revenue_estimated")
ax.set_title(label="True vs Predicted Revenue (test)", fontsize=18, fontweight="bold")

The retention predictions look very good whereas the revenue ones seem a bit higher for cohorts with higher revenue. This is also similar to the BART model.

Finally, lets get out of sample predictions for the selected subset of cohorts.

  • Retention
test_retention_hdi = az.hdi(ary=test_idata_svi["posterior_predictive"])[
    "retention_estimated"
]


def plot_test_retention_hdi_cohort(cohort_index: int, ax: plt.Axes) -> plt.Axes:
    mask = test_cohort_idx == cohort_index

    test_period_range = test_data_red_df.query(
        f"cohort == '{train_cohort_encoder.classes_[cohort_index]}'"
    )["period"]

    ax.fill_between(
        x=test_period_range,
        y1=test_retention_hdi[mask, :][:, 0],
        y2=test_retention_hdi[mask, :][:, 1],
        alpha=0.2,
        color="C3",
    )

    sns.lineplot(
        x=test_period_range,
        y=test_idata_svi["posterior_predictive"]["retention_estimated"].mean(
            dim=["chain", "draw"]
        )[mask],
        marker="o",
        color="C3",
        ax=ax,
    )
    sns.lineplot(
        x=test_period_range,
        y=test_retention[mask],
        color="C1",
        marker="o",
        label="observed (test)",
        ax=ax,
    )
    return ax
cohort_index_to_plot = [0, 1, 5, 10, 15, 20, 25, 30]

fig, axes = plt.subplots(
    nrows=len(cohort_index_to_plot),
    ncols=1,
    figsize=(15, 16),
    sharex=True,
    sharey=True,
    layout="constrained",
)

for cohort_index, ax in zip(cohort_index_to_plot, axes.flatten(), strict=True):
    plot_train_retention_hdi_cohort(cohort_index=cohort_index, ax=ax)
    plot_test_retention_hdi_cohort(cohort_index=cohort_index, ax=ax)
    ax.axvline(
        x=pd.to_datetime(period_train_test_split),
        color="black",
        linestyle="--",
        label="train/test split",
    )
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
fig.suptitle("Retention Predictions", y=1.03, fontsize=20, fontweight="bold")
  • Revenue
test_revenue_estimated_hdi = az.hdi(ary=test_idata_svi["posterior_predictive"])[
    "revenue_estimated"
]


def plot_test_revenue_hdi_cohort(cohort_index: int, ax: plt.Axes) -> plt.Axes:
    mask = test_cohort_idx == cohort_index

    test_period_range = test_data_red_df.query(
        f"cohort == '{train_cohort_encoder.classes_[cohort_index]}'"
    )["period"]

    ax.fill_between(
        x=test_period_range,
        y1=test_revenue_estimated_hdi[mask, :][:, 0],
        y2=test_revenue_estimated_hdi[mask, :][:, 1],
        alpha=0.2,
        color="C3",
    )

    sns.lineplot(
        x=test_period_range,
        y=test_idata_svi["posterior_predictive"]["revenue_estimated"].mean(
            dim=["chain", "draw"]
        )[mask],
        marker="o",
        color="C3",
        ax=ax,
    )
    sns.lineplot(
        x=test_period_range,
        y=test_revenue[mask],
        color="C1",
        marker="o",
        label="observed (test)",
        ax=ax,
    )
    return ax
cohort_index_to_plot = [0, 1, 5, 10, 15, 20, 25, 30]

fig, axes = plt.subplots(
    nrows=len(cohort_index_to_plot),
    ncols=1,
    figsize=(15, 16),
    sharex=True,
    sharey=False,
    layout="constrained",
)

for cohort_index, ax in zip(cohort_index_to_plot, axes.flatten(), strict=True):
    plot_train_revenue_hdi_cohort(cohort_index=cohort_index, ax=ax)
    plot_test_revenue_hdi_cohort(cohort_index=cohort_index, ax=ax)
    ax.axvline(
        x=pd.to_datetime(period_train_test_split),
        color="black",
        linestyle="--",
        label="train/test split",
    )
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
fig.suptitle("Revenue Predictions", y=1.03, fontsize=20, fontweight="bold")

The results look very good! Both the posterior mean and the uncertainty as a function of the cohort size.

This extension of the model allow us for greater flexibility without compromising the inference speed (it is actually much faster) or accuracy. We could actually iterate on the model development with SVI and then get posterior samples via full HMC (NUTS). This is also fast for the typical dataset sizes we have in this type of problems.