7 min read

Flax and NumPyro Toy Example

In this notebook I want to experiment with the numpyro/contrib/module.py module which allow us to integrate Flax models with NumPyro models. I am interested in this because I want to experiment with complex bayesian models with larger datasets.

Most of the main components can be found in the great blog post Bayesian Neural Networks with Flax and Numpyro. The author takes a different path working directly with potentials, but he also points out the recent addition of the numpyro/contrib/module.py module. The main difference with the model presented here is that I am using two components in the model (to model the mean and standard deviation of the data), I use stochastic variational inference instead of MCMC and I work with scaling transformations.

Another great source to learn about this model and expected behavior are the unit tests of the numpyro/contrib/module.py module 😄.

Remark: Note that there is a way of adding Flax models into PyMC, see How to use JAX ODEs and Neural Networks in PyMC.

Prepare Notebook

import arviz as az
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
import seaborn as sns
import xarray as xr
from flax import linen as nn
from jax import random
from numpyro.contrib.module import flax_module
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.util import Predictive
from sklearn.preprocessing import StandardScaler

az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

numpyro.set_host_device_count(n=4)

rng_key = random.PRNGKey(seed=42)

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"

Generate Synthetic Data

We generate a simple one-dimensional dataset with a non-linear relationship between the input and output.

n = 32 * 10
rng_key, rng_subkey = random.split(rng_key)
x = random.uniform(rng_key, shape=(n,), minval=1, maxval=jnp.pi)
mu_true = jnp.sqrt(x + 0.5) * jnp.sin(9 * x)
sigma_true = 0.15 * x**2
rng_key, rng_subkey = random.split(rng_key)
y = mu_true + sigma_true * random.normal(rng_key, shape=(n,))

Note that we are actually adding non-linearities to the mean and standard deviation of the data.

x_idx = jnp.argsort(x)

fig, ax = plt.subplots()
sns.lineplot(x=x, y=mu_true, color="C0", label=r"$\mu$", linewidth=3, ax=ax)
ax.fill_between(
    x[x_idx],
    (mu_true - 2 * sigma_true)[x_idx],
    (mu_true + 2 * sigma_true)[x_idx],
    color="C0",
    alpha=0.2,
    label=r"$\mu \pm 2 \sigma$",
)
sns.scatterplot(x=x, y=y, color="black", label="data", ax=ax)
ax.legend(loc="upper left")
ax.set_title(label="Simulated Data", fontsize=18, fontweight="bold")
ax.set(xlabel="x", ylabel="y")

Train-Test Split

We now do a simple train-test split of the data.

train_test_split = 0.7
train_idx = int(n * train_test_split)
x_train, y_train = x[:train_idx], y[:train_idx]
x_test, y_test = x[train_idx:], y[train_idx:]
# useful variables for indexing
obs_train = jnp.arange(x_train.size)
obs_test = jnp.arange(x_test.size) + x_train.size

Let’s see both datasets:

fig, ax = plt.subplots(
    nrows=2, ncols=1, sharex=True, sharey=True, figsize=(12, 9), layout="constrained"
)
sns.scatterplot(x=x_train, y=y_train, color="C0", label="train", ax=ax[0])
ax[0].legend(loc="upper left")
ax[0].set_title(label="Train Data", fontsize=18, fontweight="bold")
ax[0].set(xlabel="x", ylabel="y")
sns.scatterplot(x=x_test, y=y_test, color="C1", label="test", ax=ax[1])
ax[1].legend(loc="upper left")
ax[1].set_title(label="Test Data", fontsize=18, fontweight="bold")
ax[1].set(xlabel="x", ylabel="y")

Data Preprocessing

As we want to use a neural network to model the mean and standard deviation of the data, we need to scale the data. We use a StandardScaler from sklearn to do this.

x_scaler = StandardScaler()
x_train_scaled = x_scaler.fit_transform(x_train[:, None])
x_train_scaled = jnp.array(x_train_scaled)
x_test_scaled = x_scaler.transform(x_test[:, None])
x_test_scaled = jnp.array(x_test_scaled)

y_scaler = StandardScaler()
y_train_scaled = y_scaler.fit_transform(y_train[:, None])
y_train_scaled = y_train_scaled.squeeze()
y_train_scaled = jnp.array(y_train_scaled)
y_test_scaled = y_scaler.transform(y_test[:, None])
y_test_scaled = y_test_scaled.squeeze()
y_test_scaled = jnp.array(y_test_scaled)

Remark: Note that the feature matrices x_train and x_test have dimension \(2\) as the first dimension represent the batch component! For the target variables we simply have a vector.

fig, ax = plt.subplots(
    nrows=2, ncols=1, sharex=True, sharey=True, figsize=(12, 9), layout="constrained"
)
sns.scatterplot(
    x=x_train_scaled.squeeze(),
    y=y_train_scaled,
    color="C0",
    label="train",
    ax=ax[0],
)
ax[0].legend(loc="upper left")
ax[0].set_title(label="Train Data (Scaled)", fontsize=18, fontweight="bold")
ax[0].set(xlabel="x", ylabel="y")
sns.scatterplot(
    x=x_test_scaled.squeeze(),
    y=y_test_scaled,
    color="C1",
    label="test",
    ax=ax[1],
)
ax[1].legend(loc="upper left")
ax[1].set_title(label="Test Data (Scaled)", fontsize=18, fontweight="bold")
ax[1].set(xlabel="x", ylabel="y")

Model Specification

We now specify the model. We use a simple feed-forward neural network or multilayer perceptron (MLP) to model both the mean and standard deviation of the data. We then pass it through Normal likelihoods distribution in a NumPyro model.

First, we define the MLP with Flax:

class MLP(nn.Module):
    layers: list[int]

    @nn.compact
    def __call__(self, x):
        for num_features in self.layers[:-1]:
            x = nn.sigmoid(nn.Dense(features=num_features)(x))
        return nn.Dense(features=self.layers[-1])(x)

We now specify the NumPyro model. The main ingredient is the from flax_module from the numpyro.contrib.module module. This allows us to integrate the Flax model into the NumPyro model.

def model(x, y=None):
    mu_nn = flax_module("mu_nn", MLP(layers=[4, 4, 1]), input_shape=(1,))
    log_sigma_nn = flax_module("sigma_nn", MLP(layers=[2, 1]), input_shape=(1,))

    mu = numpyro.deterministic("mu", mu_nn(x).squeeze())
    sigma = numpyro.deterministic("sigma", jnp.exp(log_sigma_nn(x)).squeeze())

    with numpyro.plate("data", len(x)):
        numpyro.sample("likelihood", dist.Normal(loc=mu, scale=sigma), obs=y)


numpyro.render_model(
    model=model,
    model_args=(x_train_scaled, y_train_scaled),
    render_distributions=True,
    render_params=True,
)

Remark: The flax_module function consider the Flax model parameters as parameters to lear, but not really distributions. We can add priors to the parameters of the Flax model using the random_flax_module instead. We could then pass the priors as a dictionary:

mu_nn = flax_module(
    "mu_nn",
    MLP(layers=[4, 4, 1]),
    prior={
        "Dense_0.bias": dist.Cauchy(),
        "Dense_0.kernel": dist.Normal(),
        "Dense_1.bias": dist.Cauchy(),
        "Dense_1.kernel": dist.Normal(),
        "Dense_2.bias": dist.Cauchy(),
        "Dense_2.kernel": dist.Normal(),
    },
    input_shape=(1,),
)

Model Fitting

We now fit the model. We use the SVI class from NumPyro to do this.

guide = AutoNormal(model=model)
optimizer = numpyro.optim.Adam(step_size=0.01)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
n_samples = 5_000
rng_key, rng_subkey = random.split(key=rng_key)
svi_result = svi.run(rng_subkey, n_samples, x_train_scaled, y_train_scaled)

fig, ax = plt.subplots(figsize=(9, 6))
ax.plot(svi_result.losses)
ax.set_title("ELBO loss", fontsize=18, fontweight="bold")

Posterior Predictive Checks

We now generate posterior predictive samples from the model for the training and test data.

First we generate posterior samples from the model in the scaled space.

params = svi_result.params
# get posterior predictive (deterministics and likelihood)
posterior_predictive = Predictive(
    model=model, guide=guide, params=params, num_samples=4 * 4_000
)
rng_key, rng_subkey = random.split(key=rng_key)
posterior_predictive_samples = posterior_predictive(rng_subkey, x_train_scaled)

We store the posterior samples in a az.InferenceData object.

idata_svi = az.from_dict(
    posterior_predictive={
        k: np.expand_dims(a=np.asarray(v), axis=0)
        for k, v in posterior_predictive_samples.items()
    },
    coords={"obs": obs_train},
    dims={"mu": ["obs"], "sigma": ["obs"], "likelihood": ["obs"]},
)

We now scale the posterior samples back to the original space.

posterior_predictive_original_scale = {
    var_name: xr.apply_ufunc(
        y_scaler.inverse_transform,
        idata_svi["posterior_predictive"][var_name].expand_dims(dim={"_": 1}, axis=-1),
        input_core_dims=[["obs", "_"]],
        output_core_dims=[["obs", "_"]],
        vectorize=True,
    ).squeeze(dim="_")
    for var_name in ["mu", "sigma", "likelihood"]
}

Let’s see the posterior predictive checks for the training data:

fig, ax = plt.subplots()
az.plot_hdi(
    x=x_train,
    y=posterior_predictive_original_scale["likelihood"],
    hdi_prob=0.94,
    color="C0",
    smooth=False,
    fill_kwargs={"label": r"likelihood $50\%$ HDI", "alpha": 0.2},
    ax=ax,
)
az.plot_hdi(
    x=x_train,
    y=posterior_predictive_original_scale["likelihood"],
    hdi_prob=0.50,
    color="C0",
    fill_kwargs={"label": r"likelihood $94\%$ HDI", "alpha": 0.3},
    smooth=False,
    ax=ax,
)

sns.lineplot(
    x=x_train,
    y=posterior_predictive_original_scale["mu"].mean(dim=["chain", "draw"]),
    color="C0",
    linewidth=3,
    label=r"$\mu$",
    ax=ax,
)

sns.scatterplot(x=x_train, y=y_train, color="black", label="train", ax=ax)
sns.lineplot(
    x=x, y=mu_true, color="C3", label=r"$\mu_{\text{true}}$", linewidth=3, ax=ax
)
ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.1), ncol=5)
ax.set_title(label="Train Data - Posterior Predictive", fontsize=18, fontweight="bold")
ax.set(xlabel="x", ylabel="y")

The model seems to capture the mean and standard deviation of the data well. Still, not great at the right boundary.

We can run the same procedure for the test data:

predictive = Predictive(model=model, guide=guide, params=params, num_samples=4 * 4_000)
rng_key, rng_subkey = random.split(key=rng_key)
test_posterior_predictive_samples = predictive(rng_subkey, x_test_scaled)

test_idata_svi = az.from_dict(
    posterior_predictive={
        k: np.expand_dims(a=np.asarray(v), axis=0)
        for k, v in test_posterior_predictive_samples.items()
    },
    coords={"obs": obs_test},
    dims={"mu": ["obs"], "sigma": ["obs"], "likelihood": ["obs"]},
)

test_posterior_predictive_original_scale = {
    var_name: xr.apply_ufunc(
        y_scaler.inverse_transform,
        test_idata_svi["posterior_predictive"][var_name].expand_dims(
            dim={"_": 1}, axis=-1
        ),
        input_core_dims=[["obs", "_"]],
        output_core_dims=[["obs", "_"]],
        vectorize=True,
    ).squeeze(dim="_")
    for var_name in ["mu", "sigma", "likelihood"]
}

We generate the posterior predictive checks for the test data:

fig, ax = plt.subplots()
az.plot_hdi(
    x=x_test,
    y=test_posterior_predictive_original_scale["likelihood"],
    hdi_prob=0.94,
    color="C1",
    smooth=False,
    fill_kwargs={"label": r"likelihood $50\%$ HDI", "alpha": 0.2},
    ax=ax,
)
az.plot_hdi(
    x=x_test,
    y=test_posterior_predictive_original_scale["likelihood"],
    hdi_prob=0.50,
    color="C1",
    fill_kwargs={"label": r"likelihood $94\%$ HDI", "alpha": 0.3},
    smooth=False,
    ax=ax,
)

sns.lineplot(
    x=x_test,
    y=test_posterior_predictive_original_scale["mu"].mean(dim=["chain", "draw"]),
    color="C1",
    linewidth=3,
    label=r"$\mu$",
    ax=ax,
)

sns.scatterplot(
    x=x_test,
    y=y_test,
    color="black",
    label="test",
    ax=ax,
)
sns.lineplot(
    x=x, y=mu_true, color="C3", label=r"$\mu_{\text{true}}$", linewidth=3, ax=ax
)
ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.1), ncol=5)
ax.set_title(label="Test Data - Posterior Predictive", fontsize=18, fontweight="bold")
ax.set(xlabel="x", ylabel="y")

Overall, the results seem very reasonable. The right boundary is still not great. In general we do not expect these models to capture patterns outside the range of the training data.

Even though the data and the model are very simple, this just serves as a concrete example for more complex scenarios were complex feature interactions trough neural networks can be beneficial in the context of bayesian modelling.