8 min read

From Pyro to NumPyro: Forecasting Hierarchical Models - Part II

In this second notebook, we continue working on the NumPyro implementation of the hierarchical forecasting models presented in Pyro’s forecasting documentation: Forecasting III: hierarchical models. In this second part, we extend the model described in the first part From Pyro to NumPyro: Forecasting Hierarchical Models - Part I by adding all stations to the model.

Prepare Notebook

import arviz as az
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
from jax import random
from jaxtyping import Array, Float
from numpyro.contrib.control_flow import scan
from numpyro.infer import SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.reparam import LocScaleReparam
from pyro.contrib.examples.bart import load_bart_od

# Plot styling shared by every figure in this notebook.
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

# Ask NumPyro to expose 4 host (CPU) devices to JAX.
numpyro.set_host_device_count(n=4)

# Root PRNG key; later cells split it before every random operation.
rng_key = random.PRNGKey(seed=42)

%load_ext autoreload
%autoreload 2
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype
%config InlineBackend.figure_format = "retina"
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The jaxtyping extension is already loaded. To reload it, use:
  %reload_ext jaxtyping

Read Data

# Load the BART origin-destination ride counts shipped with Pyro's examples.
dataset = load_bart_od()
# Quick look at what the dataset contains: its keys, the shape of the
# counts tensor and the station codes.
print(dataset.keys())
print(dataset["counts"].shape)
print(" ".join(dataset["stations"]))
dict_keys(['stations', 'start_date', 'counts'])
torch.Size([78888, 50, 50])
12TH 16TH 19TH 24TH ANTC ASHB BALB BAYF BERY CAST CIVC COLM COLS CONC DALY DBRK DELN DUBL EMBR FRMT FTVL GLEN HAYW LAFY LAKE MCAR MLBR MLPT MONT NBRK NCON OAKL ORIN PCTR PHIL PITT PLZA POWL RICH ROCK SANL SBRN SFIA SHAY SSAN UCTY WARM WCRK WDUB WOAK

In this second example, we model all the rides from all stations to all other stations.

# Move the time axis last, so the tensor is laid out as
# (origin, destination, time), and work with log(1 + count).
data = dataset["counts"].permute(1, 2, 0).log1p()
# Total number of time steps. Time is the LAST axis after the permute;
# axis -2 is the destination-station axis.
T = data.shape[-1]
print(data.shape)
torch.Size([50, 50, 78888])

Train - Test Split

As in the first example, we train on the 90 days of data immediately preceding the test period.

# Index boundaries along the time (last) axis. The multipliers below treat
# one step as one hour: 24 * 7 * 2 steps = two weeks held out for testing,
# 24 * 90 steps = 90 days used for training.
T2 = data.size(-1)  # end
T1 = T2 - 24 * 7 * 2  # train/test split
T0 = T1 - 24 * 90  # beginning: train on 90 days of data
# Convert the torch slices into JAX arrays: the full window, the training
# part and the test part.
y = jnp.array(data[..., T0:T2])
y_train = jnp.array(data[..., T0:T1])
y_test = jnp.array(data[..., T1:T2])

print(f"y: {y.shape}")
print(f"y_train: {y_train.shape}")
print(f"y_test: {y_test.shape}")
y: (50, 50, 2496)
y_train: (50, 50, 2160)
y_test: (50, 50, 336)
# Number of stations, read from the destination axis of the training data.
n_stations = y_train.shape[-2]

# Integer time indices for the full window and for each split.
time = jnp.arange(T0, T2)
time_train = jnp.arange(T0, T1)
time_test = jnp.arange(T1, T2)

# Length of each split along the time axis.
t_max_train = time_train.size
t_max_test = time_test.size

# The model only reads the covariates' shape, so all-zero placeholders with
# the same shape as the corresponding targets are enough.
covariates_train = jnp.zeros_like(y_train)
covariates_test = jnp.zeros_like(y_test)
covariates = jnp.zeros_like(y)

# Sanity checks: the splits partition the full window and all shapes agree.
assert time.size == time_train.size + time_test.size
assert (n_stations, n_stations, t_max_train) == y_train.shape
assert (n_stations, n_stations, t_max_test) == y_test.shape
assert y.shape == covariates.shape
assert y_train.shape == covariates_train.shape
assert y_test.shape == covariates_test.shape

Repeating Seasonal Features

We also need the JAX version of the periodic_repeat function.

def periodic_repeat_jax(tensor: Array, size: int, dim: int) -> Array:
    """
    Tile one period of values along an axis and trim to an exact length.

    JAX counterpart of ``pyro.ops.tensor_utils.periodic_repeat``.

    Parameters
    ----------
    tensor : Array
        Array holding a single period along axis ``dim``.
    size : int
        Non-negative length of the output along axis ``dim``.
    dim : int
        Axis (positive or negative) along which the period is repeated.

    Returns
    -------
    Array
        ``tensor`` repeated along ``dim`` as many whole periods as needed
        and then truncated to ``size`` entries; other axes are unchanged.

    References
    ----------
    https://docs.pyro.ai/en/stable/ops.html#pyro.ops.tensor_utils.periodic_repeat
    """
    assert isinstance(size, int) and size >= 0
    assert isinstance(dim, int)

    ndim = tensor.ndim
    # Normalise to a negative axis index so the shape lookup below behaves
    # the same for positive and negative `dim`.
    axis = dim - ndim if dim >= 0 else dim
    period = tensor.shape[axis]

    # Ceiling division: smallest number of whole periods covering `size`.
    n_copies = (size + period - 1) // period

    # Position of the repeated axis counted from the left.
    target = axis + ndim
    reps = tuple(n_copies if k == target else 1 for k in range(ndim))
    window = tuple(
        slice(None, size) if k == target else slice(None) for k in range(ndim)
    )

    # Tile whole periods, then cut the surplus off the repeated axis.
    return jnp.tile(tensor, reps)[window]

Model Specification

In this model, the local level dynamic is driven by the destination station. On the other hand, the seasonal components and the noise scales come as a sum of the origin and destination stations. The model structure is very similar to the one presented in the first example.

def model(
    covariates: Float[Array, "n_series n_series t_max"],
    y: Float[Array, "n_series n_series t_max"] | None = None,
) -> None:
    """
    Hierarchical local-level model for all origin -> destination series.

    The mean of each series is the sum of a per-destination random-walk
    level, origin and destination hour-of-week seasonal effects, and a
    static origin/destination pairwise effect. The observation noise scale
    is the sum of an origin scale and a destination scale.

    Parameters
    ----------
    covariates : Float[Array, "n_series n_series t_max"]
        Only its shape is read; it fixes the number of stations and the
        number of time steps the model is unrolled for.
    y : Float[Array, "n_series n_series t_max"] | None
        Observations used to condition the ``"obs"`` site. The predictive
        calls in this notebook leave it as ``None`` to draw ``"obs"``
        instead.
    """
    # Get the time and feature dimensions
    # (the origin and destination axes share the name `n_series`)
    n_series, n_series, t_max = covariates.shape

    # Define the plates to be able to use them below
    origin_plate = numpyro.plate("origin", n_series, dim=-3)
    destin_plate = numpyro.plate("destin", n_series, dim=-2)
    hour_of_week_plate = numpyro.plate("hour_of_week", 24 * 7, dim=-1)

    # Global scale for the drift
    drift_scale = numpyro.sample("drift_scale", dist.LogNormal(loc=-20, scale=5))

    # Sample the centered parameter for the LocScaleReparam
    # (a value in [0, 1] passed as `centered` to the reparametrization below)
    destin_centered = numpyro.sample("destin_centered", dist.Uniform(low=0, high=1))

    # One seasonal effect per origin station and hour of the week
    with origin_plate, hour_of_week_plate:
        origin_seasonal = numpyro.sample("origin_seasonal", dist.Normal(loc=0, scale=5))

    with destin_plate:
        with (
            # No explicit `dim` is given for the time plate;
            # NOTE(review): presumably it lands on dim -1, which is what the
            # `drift[..., t]` indexing in `transition_fn` relies on.
            numpyro.plate("time", t_max),
            numpyro.handlers.reparam(
                config={"drift": LocScaleReparam(centered=destin_centered)}
            ),
        ):
            # Sample the drift parameters
            # We have one drift parameter per time series (station) and time point
            drift = numpyro.sample("drift", dist.Normal(loc=0, scale=drift_scale))

        with hour_of_week_plate:
            # Sample the seasonal parameters
            # We have one seasonal parameter per hour of the week and per station
            destin_seasonal = numpyro.sample(
                "destin_seasonal", dist.Normal(loc=0, scale=5)
            )

    # We model a static pairwise station->station affinity, which e.g.
    # can compensate for the fact that people tend not to travel from
    # a station to itself.
    with origin_plate, destin_plate:
        pairwise = numpyro.sample("pairwise", dist.Normal(0, 1))

    # We model the origin and destination scales separately
    # and then add them together to get the final scale.
    with origin_plate:
        origin_scale = numpyro.sample("origin_scale", dist.LogNormal(-5, 5))
    with destin_plate:
        destin_scale = numpyro.sample("destin_scale", dist.LogNormal(-5, 5))
        # Plain arithmetic (no sample site): the two scales broadcast
        # against each other over the origin and destination axes.
        scale = origin_scale + destin_scale

    # Repeat the seasonal parameters to match the length of the time series
    seasonal = origin_seasonal + destin_seasonal
    seasonal_repeat = periodic_repeat_jax(seasonal, t_max, dim=-1)

    # Define the local level transition function
    def transition_fn(carry, t):
        "Local level transition function"
        # Random walk: the new level is the previous level plus the drift
        # at time step t (one value per destination series).
        previous_level = carry
        current_level = previous_level + drift[..., t]
        return current_level, current_level

    # Compute the latent levels using scan
    # (levels start at zero; scan stacks the per-step outputs along the
    # leading axis)
    _, pred_levels = scan(
        transition_fn, init=jnp.zeros((n_series,)), xs=jnp.arange(t_max)
    )

    # We need to transpose the prediction levels to match the shape of the data
    pred_levels = pred_levels.transpose(1, 0)

    # Compute the mean of the model
    # (level + seasonality + pairwise effect, combined via broadcasting)
    mu = pred_levels + seasonal_repeat + pairwise

    # Sample the observations
    # (the "obs" site is conditioned on `y` through the condition handler)
    with numpyro.handlers.condition(data={"obs": y}):
        numpyro.sample("obs", dist.Normal(loc=mu, scale=scale))

We can now visualize the model structure.

# Draw the model graph for the training inputs, annotating each node with
# its distribution and including parameters.
numpyro.render_model(
    model=model,
    model_kwargs={"covariates": covariates_train, "y": y_train},
    render_distributions=True,
    render_params=True,
)

Prior Predictive Checks

As usual (highly recommended!), we should perform prior predictive checks.

# Draw 500 samples of the "obs" site from the prior (no observations are
# passed to the model here).
prior_predictive = Predictive(model=model, num_samples=500, return_sites=["obs"])

# Use a fresh subkey for this random operation.
rng_key, rng_subkey = random.split(rng_key)

prior_samples = prior_predictive(rng_subkey, covariates_train)

# Wrap the samples in an ArviZ InferenceData object. `v[None, ...]` adds a
# leading axis of length one that serves as the chain dimension.
idata_prior = az.from_dict(
    prior_predictive={k: v[None, ...] for k, v in prior_samples.items()},
    coords={
        "time_train": time_train,
        "origin": dataset["stations"],
        "destin": dataset["stations"],
    },
    dims={"obs": ["origin", "destin", "time_train"]},
)

Let’s plot the prior predictive distribution for the first \(8\) origin stations, with ANTC as the destination station.

# Destination station to inspect and its position in the station list.
station = "ANTC"
idx = dataset["stations"].index(station)

# One panel per origin station (the first 8), all sharing both axes.
fig, axes = plt.subplots(
    nrows=8, ncols=1, figsize=(15, 18), sharex=True, sharey=True, layout="constrained"
)
for i, ax in enumerate(axes):
    # Overlay a wide and a narrow HDI band; the narrower one is drawn more
    # opaque (alpha grows with j).
    for j, hdi_prob in enumerate([0.94, 0.5]):
        az.plot_hdi(
            # Only the last three weeks of the training window are shown.
            time_train[time_train >= T1 - 3 * (24 * 7)],
            idata_prior["prior_predictive"]["obs"]
            .sel(destin=station)
            .isel(origin=i)[:, :, time_train >= T1 - 3 * (24 * 7)]
            .clip(min=0),
            hdi_prob=hdi_prob,
            color="C0",
            fill_kwargs={
                "alpha": 0.3 + 0.2 * j,
                "label": f"{hdi_prob*100:.0f}% HDI (train)",
            },
            smooth=False,
            ax=ax,
        )

    # Observed (log1p) counts for the same origin/destination pair and window.
    ax.plot(
        time_train[time_train >= T1 - 3 * (24 * 7)],
        data[i, idx, T1 - 3 * (24 * 7) : T1],
        "black",
        lw=1,
        label="Truth",
    )

    # Place the legend outside the axes, to the right of each panel.
    ax.legend(
        bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0, fontsize=12
    )

fig.suptitle("Prior predictive checks", fontsize=18, fontweight="bold");

Overall, the prior ranges look very reasonable (even too wide).

Inference with SVI

We now fit the model to the data using stochastic variational inference. This time the model takes longer to fit than the first one (about \(3.5\) minutes compared to \(45\) seconds).

%%time

# Variational family: an automatically constructed Normal guide for the model.
guide = AutoNormal(model)
optimizer = numpyro.optim.Adam(step_size=0.1)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
num_steps = 10_000

# Use a fresh subkey for the optimization run.
rng_key, rng_subkey = random.split(key=rng_key)

# Positional arguments after `num_steps` are forwarded to the model and
# guide: the covariates and the observed training series.
svi_result = svi.run(
    rng_subkey,
    num_steps,
    covariates_train,
    y_train,
)

# Plot the loss trajectory on a log scale to check convergence.
fig, ax = plt.subplots(figsize=(9, 6))
ax.plot(svi_result.losses)
ax.set_yscale("log")
ax.set_title("ELBO loss", fontsize=18, fontweight="bold");
100%|██████████| 10000/10000 [03:23<00:00, 49.10it/s, init loss: 1727067648.0000, avg. loss [9501-10000]: 3742141.6960]


CPU times: user 9min 45s, sys: 1min 46s, total: 11min 31s
Wall time: 3min 28s

The resulting ELBO loss looks good!

Posterior Predictive Check

Next, we generate posterior predictive samples for the forecast for each of the stations pairs.

# Posterior predictive sampler: latent variables come from the fitted guide
# (using the optimized parameters), and only the "obs" site is returned.
posterior = Predictive(
    model=model,
    guide=guide,
    params=svi_result.params,
    num_samples=200,
    return_sites=["obs"],
)
rng_key, rng_subkey = random.split(rng_key)

# NOTE(review): the same `rng_subkey` is passed to both `posterior(...)`
# calls below; confirm that reusing the key across the two calls is intended.

# In-sample predictions over the training window.
idata_train = az.from_dict(
    posterior_predictive={
        k: v[None, ...] for k, v in posterior(rng_subkey, covariates_train).items()
    },
    coords={
        "time_train": time_train,
        "origin": dataset["stations"],
        "destin": dataset["stations"],
    },
    dims={"obs": ["origin", "destin", "time_train"]},
)

# Predictions over the full window (train + test); the test part is
# selected later by slicing on the "time" coordinate.
idata_test = az.from_dict(
    posterior_predictive={
        k: v[None, ...] for k, v in posterior(rng_subkey, covariates).items()
    },
    coords={
        "time": time,
        "origin": dataset["stations"],
        "destin": dataset["stations"],
    },
    dims={"obs": ["origin", "destin", "time"]},
)

To evaluate the model performance, we compute the CRPS for the training and test data. For comparison purposes, we clip the predictions at zero to ensure they are non-negative.

@jax.jit
def crps(
    truth: Float[Array, "n_series n_series t_max"],
    pred: Float[Array, "n_samples n_series n_series t_max"],
    sample_weight: Float[Array, " t_max"] | None = None,
) -> Float[Array, ""]:
    if pred.shape[1:] != (1,) * (pred.ndim - truth.ndim - 1) + truth.shape:
        raise ValueError(
            f"""Expected pred to have one extra sample dim on left.
            Actual shapes: {pred.shape} versus {truth.shape}"""
        )

    absolute_error = jnp.mean(jnp.abs(pred - truth), axis=0)

    num_samples = pred.shape[0]
    if num_samples == 1:
        return jnp.average(absolute_error, weights=sample_weight)

    pred = jnp.sort(pred, axis=0)
    diff = pred[1:] - pred[:-1]
    weight = jnp.arange(1, num_samples) * jnp.arange(num_samples - 1, 0, -1)
    weight = weight.reshape(weight.shape + (1,) * (diff.ndim - 1))

    per_obs_crps = absolute_error - jnp.sum(diff * weight, axis=0) / num_samples**2
    return jnp.average(per_obs_crps, weights=sample_weight)


# CRPS on the training window. `.sel(chain=0)` drops the chain axis and the
# predictions are clipped at zero before scoring.
crps_train = crps(
    y_train,
    jnp.array(idata_train["posterior_predictive"]["obs"].sel(chain=0).clip(min=0)),
)

# CRPS on the test window: keep only the time coordinates selected by
# `slice(T1, T2)` from the full-window predictions.
crps_test = crps(
    y_test,
    jnp.array(
        idata_test["posterior_predictive"]["obs"]
        .sel(chain=0)
        .sel(time=slice(T1, T2))
        .clip(min=0)
    ),
)

Finally, we reproduce the model fit and plot from the Pyro example.

# Destination station to inspect and its position in the station list.
station = "ANTC"
idx = dataset["stations"].index(station)

# One panel per origin station (the first 8), all sharing both axes.
fig, axes = plt.subplots(
    nrows=8, ncols=1, figsize=(15, 18), sharex=True, sharey=True, layout="constrained"
)
for i, ax in enumerate(axes):
    # Overlay a wide and a narrow HDI band; the narrower one is drawn more
    # opaque (alpha grows with j).
    for j, hdi_prob in enumerate([0.94, 0.5]):
        # In-sample fit: last week of the training window.
        az.plot_hdi(
            time_train[time_train >= T1 - 24 * 7],
            idata_train["posterior_predictive"]["obs"]
            .sel(destin=station)
            .isel(origin=i)[:, :, time_train >= T1 - 24 * 7]
            .clip(min=0),
            hdi_prob=hdi_prob,
            color="C0",
            fill_kwargs={
                "alpha": 0.3 + 0.2 * j,
                "label": f"{hdi_prob*100:.0f}% HDI (train)",
            },
            smooth=False,
            ax=ax,
        )

        # Forecast: the test window, taken from the full-window predictions.
        az.plot_hdi(
            time[time >= T1],
            idata_test["posterior_predictive"]["obs"]
            .sel(destin=station)
            .isel(origin=i)[:, :, time >= T1]
            .clip(min=0),
            hdi_prob=hdi_prob,
            color="C1",
            fill_kwargs={
                "alpha": 0.2 + 0.2 * j,
                "label": f"{hdi_prob*100:.0f}% HDI (test)",
            },
            smooth=False,
            ax=ax,
        )

    # Observed (log1p) counts from one week before the split to the end.
    ax.plot(
        time[time >= T1 - 24 * 7],
        data[i, idx, T1 - 24 * 7 : T2],
        "black",
        lw=1,
        label="Truth",
    )

    # Mark where the training data ends and the forecast begins.
    ax.axvline(T1, color="C3", linestyle="--", label="Train/test split")

    # Place the legend outside the axes, to the right of each panel.
    ax.legend(
        bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0, fontsize=12
    )

# Report both CRPS values in the figure title.
fig.suptitle(
    f"""Posterior predictive checks

    Train CRPS: {crps_train:.4f} | Test CRPS: {crps_test:.4f}
    """,
    fontsize=18,
    fontweight="bold",
);