In this notebook we provide a NumPyro implementation of the first model presented in the Pyro forecasting documentation: Forecasting III: hierarchical models. This model generalizes the local level model with seasonality presented in the univariate example Forecasting I: univariate, heavy tailed (see From Pyro to NumPyro: Forecasting a univariate, heavy tailed time series for the corresponding NumPyro implementation).
In this example, we continue working with the BART train ridership dataset.
Prepare Notebook
import arviz as az
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
import torch
from jax import random
from jaxtyping import Array, Float
from numpyro.contrib.control_flow import scan
from numpyro.infer import SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.reparam import LocScaleReparam
from pyro.contrib.examples.bart import load_bart_od
from pyro.ops.tensor_utils import periodic_repeat
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"
numpyro.set_host_device_count(n=4)
rng_key = random.PRNGKey(seed=42)
%load_ext autoreload
%autoreload 2
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype
%config InlineBackend.figure_format = "retina"
Read Data
dataset = load_bart_od()
print(dataset.keys())
print(dataset["counts"].shape)
print(" ".join(dataset["stations"]))
dict_keys(['stations', 'start_date', 'counts'])
torch.Size([78888, 50, 50])
12TH 16TH 19TH 24TH ANTC ASHB BALB BAYF BERY CAST CIVC COLM COLS CONC DALY DBRK DELN DUBL EMBR FRMT FTVL GLEN HAYW LAFY LAKE MCAR MLBR MLPT MONT NBRK NCON OAKL ORIN PCTR PHIL PITT PLZA POWL RICH ROCK SANL SBRN SFIA SHAY SSAN UCTY WARM WCRK WDUB WOAK
For this first model, we only model the rides into the Embarcadero station from each of the other \(50\) stations.
T = dataset["counts"].shape[0]
data = dataset["counts"][:, :, dataset["stations"].index("EMBR")].log1p()
print(data.shape)
torch.Size([78888, 50])
Let's try to visualize these time series.
fig, ax = plt.subplots()
ax.plot(data[-24 * 7 * 2 :], "-", c="C0", alpha=0.1, markeredgewidth=0)
ax.set(
title="Hourly arrivals to EMBR for two weeks",
ylabel="log1p(# rides)",
xlabel="Hour after 2011-01-01",
);
Train - Test Split
For training, we use the 90 days of data immediately preceding the two-week test period.
T2 = data.size(-2) # end
T1 = T2 - 24 * 7 * 2 # train/test split
T0 = T1 - 24 * 90 # beginning: train on 90 days of data
y = jnp.array(data[T0:T2])
y_train = jnp.array(data[T0:T1])
y_test = jnp.array(data[T1:T2])
print(f"y: {y.shape}")
print(f"y_train: {y_train.shape}")
print(f"y_test: {y_test.shape}")
y: (2496, 50)
y_train: (2160, 50)
y_test: (336, 50)
n_stations = y_train.shape[-1]
time = jnp.array(range(T0, T2))
time_train = jnp.array(range(T0, T1))
t_max_train = time_train.size
time_test = jnp.array(range(T1, T2))
t_max_test = time_test.size
assert time_train.size + time_test.size == time.size
assert y_train.shape == (t_max_train, n_stations)
assert y_test.shape == (t_max_test, n_stations)
As in the previous example (From Pyro to NumPyro: Forecasting a univariate, heavy tailed time series), we use the covariates input tensor only to encode the length of the time series. We could of course use it to carry real covariates, but we do not use any in this example (see the hypothetical sketch below).
covariates = jnp.zeros_like(y)
covariates_train = jnp.zeros_like(y_train)
covariates_test = jnp.zeros_like(y_test)
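If we did want to include a real covariate, we could, for example, encode a scaled linear trend shared by all stations. The snippet below is only a hypothetical sketch; the names trend and covariates_train_trend are not used anywhere else in this notebook.
# Hypothetical sketch (not used below): a scaled linear trend as a covariate for every station.
trend = jnp.linspace(0, 1, num=y_train.shape[0])[:, None]  # shape (t_max_train, 1)
covariates_train_trend = jnp.broadcast_to(trend, y_train.shape)  # shape (t_max_train, n_stations)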
Repeating Seasonal Features
In the univariate case example we studied a very handy function to generate Fourier modes, periodic_features. Here we instead need to repeat the seasonal parameters along the time axis, and periodic_repeat does exactly that. It has two main parameters:
size (int) – desired size of the result along dimension dim.
dim (int) – the tensor dimension along which to repeat.
Let's look at some examples from the docstring.
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
periodic_repeat(x, size=4, dim=0)
tensor([[1, 2, 3],
[4, 5, 6],
[1, 2, 3],
[4, 5, 6]])
periodic_repeat(x, size=4, dim=1)
tensor([[1, 2, 3, 1],
[4, 5, 6, 4]])
The translation from PyTorch to JAX is not that hard (thank you, GitHub Copilot!).
def periodic_repeat_jax(tensor: Array, size: int, dim: int) -> Array:
"""
Repeat a period-sized tensor up to given size using JAX.
Parameters
----------
tensor : Array
A JAX array to be repeated.
size : int
Desired size of the result along dimension `dim`.
dim : int
The tensor dimension along which to repeat.
Returns
-------
Array
The repeated tensor.
References
----------
https://docs.pyro.ai/en/stable/ops.html#pyro.ops.tensor_utils.periodic_repeat
"""
assert isinstance(size, int) and size >= 0
assert isinstance(dim, int)
if dim >= 0:
dim -= tensor.ndim
period = tensor.shape[dim]
repeats = [1] * tensor.ndim
repeats[dim] = (size + period - 1) // period
result = jnp.tile(tensor, repeats)
slices = [slice(None)] * tensor.ndim
slices[dim] = slice(None, size)
return result[tuple(slices)]
Let's verify that the function works as expected for some examples.
assert jnp.allclose(
periodic_repeat_jax(jnp.array(x), 4, 0), jnp.array(periodic_repeat(x, 4, 0))
)
assert jnp.allclose(
periodic_repeat_jax(jnp.array(x), 4, 1), jnp.array(periodic_repeat(x, 4, 1))
)
assert jnp.allclose(
periodic_repeat_jax(jnp.array(x), 50, 1), jnp.array(periodic_repeat(x, 50, 1))
)
Model Specification
This first hierarchical model extends the local level model with seasonality seen in the univariate case example, From Pyro to NumPyro: Forecasting a univariate, heavy tailed time series.
def model(
covariates: Float[Array, "t_max n_series"],
y: Float[Array, "t_max n_series"] | None = None,
) -> None:
# Get the time and feature dimensions
t_max, n_series = covariates.shape
# Global scale for the drift
drift_scale = numpyro.sample("drift_scale", dist.LogNormal(loc=-20, scale=5))
# Scale for the Normal likelihood
sigma = numpyro.sample("sigma", dist.LogNormal(loc=-5, scale=5))
# Sample the centered parameter for the LocScaleReparam
centered = numpyro.sample("centered", dist.Uniform(low=0, high=1))
with numpyro.plate("n_series", n_series):
with (
numpyro.plate("time", t_max),
numpyro.handlers.reparam(
config={"drift": LocScaleReparam(centered=centered)}
),
):
# Sample the drift parameters
# We have one drift parameter per time series (station) and time point
drift = numpyro.sample("drift", dist.Normal(loc=0, scale=drift_scale))
with numpyro.plate("hour_of_week", 24 * 7):
# Sample the seasonal parameters
# We have one seasonal parameter per hour of the week and per station
seasonal = numpyro.sample("seasonal", dist.Normal(loc=0, scale=5))
# Repeat the seasonal parameters to match the length of the time series
seasonal_repeat = periodic_repeat_jax(seasonal, t_max, 0)
# Define the local level transition function
def transition_fn(carry, t):
"Local level transition function"
previous_level = carry
current_level = previous_level + drift[t]
return current_level, current_level
# Compute the latent levels using scan
_, pred_levels = scan(
transition_fn, init=jnp.zeros((n_series,)), xs=jnp.arange(t_max)
)
# Compute the mean of the model
mu = pred_levels + seasonal_repeat
# Sample the observations
with numpyro.handlers.condition(data={"obs": y}):
numpyro.sample("obs", dist.Normal(loc=mu, scale=sigma))
We can now visualize the model structure.
numpyro.render_model(
model=model,
model_kwargs={"covariates": covariates_train, "y": y_train},
render_distributions=True,
render_params=True,
)
Prior Predictive Checks
As usual (highly recommended!), we should perform prior predictive checks.
prior_predictive = Predictive(model=model, num_samples=2_000, return_sites=["obs"])
rng_key, rng_subkey = random.split(rng_key)
prior_samples = prior_predictive(rng_subkey, covariates_train)
idata_prior = az.from_dict(
prior_predictive={k: v[None, ...] for k, v in prior_samples.items()},
coords={"time_train": time_train, "n_series": jnp.arange(n_stations)},
dims={"obs": ["time_train", "n_series"]},
)
Let's plot the prior predictive distribution for the first \(8\) stations.
fig, axes = plt.subplots(
nrows=8, ncols=1, figsize=(15, 18), sharex=True, sharey=True, layout="constrained"
)
for i, ax in enumerate(axes):
for j, hdi_prob in enumerate([0.94, 0.5]):
az.plot_hdi(
time_train[time_train >= T1 - 3 * (24 * 7)],
idata_prior["prior_predictive"]["obs"].sel(n_series=i)[
:, :, time_train >= T1 - 3 * (24 * 7)
],
hdi_prob=hdi_prob,
color="C0",
fill_kwargs={
"alpha": 0.3 + 0.2 * j,
"label": f"{hdi_prob*100:.0f}% HDI (train)",
},
smooth=False,
ax=ax,
)
ax.plot(
time_train[time_train >= T1 - 3 * (24 * 7)],
data[T1 - 3 * (24 * 7) : T1, i],
"black",
lw=1,
label="Truth",
)
ax.axvline(T1, color="C3", linestyle="--", label="Train/test split")
ax.legend(
bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0, fontsize=12
)
fig.suptitle("Prior predictive checks", fontsize=18, fontweight="bold");
Overall, the prior ranges look very reasonable.
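As a small numeric complement to the visual check, we can also look at a few quantiles of the prior predictive samples on the log1p scale; a minimal sketch:
# Quick numeric summary of the prior predictive samples (log1p scale).
prior_obs = prior_samples["obs"]  # shape (num_samples, t_max_train, n_stations)
print(jnp.quantile(prior_obs, q=jnp.array([0.03, 0.5, 0.97])))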
Inference with SVI
We now fit the model to the data using stochastic variational inference.
%%time
guide = AutoNormal(model)
optimizer = numpyro.optim.Adam(step_size=0.05)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
num_steps = 15_000
rng_key, rng_subkey = random.split(key=rng_key)
svi_result = svi.run(
rng_subkey,
num_steps,
covariates_train,
y_train,
)
fig, ax = plt.subplots(figsize=(9, 6))
ax.plot(svi_result.losses)
ax.set_yscale("log")
ax.set_title("ELBO loss", fontsize=18, fontweight="bold");
100%|██████████| 15000/15000 [00:42<00:00, 356.57it/s, init loss: 81283016.0000, avg. loss [14251-15000]: 99854.5345]
CPU times: user 1min 19s, sys: 35.9 s, total: 1min 54s
Wall time: 44.6 s
The resulting ELBO loss looks good!
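Before generating forecasts, we can also peek at the fitted guide. For example, AutoNormal exposes a median method that returns point estimates of the latent sites; a quick sketch:
# Posterior medians of the latent sites according to the fitted AutoNormal guide.
posterior_median = guide.median(svi_result.params)
print({name: jnp.shape(value) for name, value in posterior_median.items()})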
Posterior Predictive Check
Next, we generate posterior predictive samples of the forecast for each of the \(50\) stations.
posterior = Predictive(
model=model,
guide=guide,
params=svi_result.params,
num_samples=1_500,
return_sites=["obs"],
)
rng_key, rng_subkey = random.split(rng_key)
idata_train = az.from_dict(
posterior_predictive={
k: v[None, ...] for k, v in posterior(rng_subkey, covariates_train).items()
},
coords={"time_train": time_train, "n_series": jnp.arange(n_stations)},
dims={"obs": ["time_train", "n_series"]},
)
idata_test = az.from_dict(
posterior_predictive={
k: v[None, ...] for k, v in posterior(rng_subkey, covariates).items()
},
coords={"time": time, "n_series": jnp.arange(n_stations)},
dims={"obs": ["time", "n_series"]},
)
As in the univariate case example, we compute the CRPS for the training and test data.
@jax.jit
def crps(
truth: Float[Array, "t_max n_series"],
pred: Float[Array, "n_samples t_max n_series"],
sample_weight: Float[Array, " t_max"] | None = None,
) -> Float[Array, ""]:
if pred.shape[1:] != (1,) * (pred.ndim - truth.ndim - 1) + truth.shape:
raise ValueError(
f"""Expected pred to have one extra sample dim on left.
Actual shapes: {pred.shape} versus {truth.shape}"""
)
absolute_error = jnp.mean(jnp.abs(pred - truth), axis=0)
num_samples = pred.shape[0]
if num_samples == 1:
return jnp.average(absolute_error, weights=sample_weight)
pred = jnp.sort(pred, axis=0)
diff = pred[1:] - pred[:-1]
weight = jnp.arange(1, num_samples) * jnp.arange(num_samples - 1, 0, -1)
weight = weight.reshape(weight.shape + (1,) * (diff.ndim - 1))
per_obs_crps = absolute_error - jnp.sum(diff * weight, axis=0) / num_samples**2
return jnp.average(per_obs_crps, weights=sample_weight)
# For the purposes of comparison, we clip the predictions to be non-negative.
# But we keep the original values in the idata object.
crps_train = crps(
y_train,
jnp.array(idata_train["posterior_predictive"]["obs"].sel(chain=0)).clip(min=0),
)
crps_test = crps(
y_test,
jnp.array(
idata_test["posterior_predictive"]["obs"]
.sel(chain=0)
.sel(time=slice(T1, T2))
.clip(min=0)
),
)
Finally, we reproduce the model fit and plot from the Pyro example.
christmas_index = 78736
fig, axes = plt.subplots(
nrows=8, ncols=1, figsize=(15, 18), sharex=True, sharey=True, layout="constrained"
)
for i, ax in enumerate(axes):
for j, hdi_prob in enumerate([0.94, 0.5]):
# For the purposes of comparison, we clip the predictions to be non-negative.
# But we keep the original values in the idata object.
az.plot_hdi(
time_train[time_train >= T1 - 24 * 7],
idata_train["posterior_predictive"]["obs"]
.sel(n_series=i)[:, :, time_train >= T1 - 24 * 7]
.clip(min=0),
hdi_prob=hdi_prob,
color="C0",
fill_kwargs={
"alpha": 0.3 + 0.2 * j,
"label": f"{hdi_prob*100:.0f}% HDI (train)",
},
smooth=False,
ax=ax,
)
az.plot_hdi(
time[time >= T1],
idata_test["posterior_predictive"]["obs"]
.sel(n_series=i)[:, :, time >= T1]
.clip(min=0),
hdi_prob=hdi_prob,
color="C1",
fill_kwargs={
"alpha": 0.2 + 0.2 * j,
"label": f"{hdi_prob*100:.0f}% HDI (test)",
},
smooth=False,
ax=ax,
)
ax.axvline(christmas_index, color="C2", lw=20, alpha=0.2, label="Christmas")
ax.plot(
time[time >= T1 - 24 * 7],
data[T1 - 24 * 7 : T2, i],
"black",
lw=1,
label="Truth",
)
ax.axvline(T1, color="C3", linestyle="--", label="Train/test split")
ax.legend(
bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0, fontsize=12
)
ax.text(
christmas_index,
-3,
"Christmas",
color="C2",
fontsize=10,
fontweight="bold",
horizontalalignment="center",
)
fig.suptitle(
f"""Posterior predictive checks
Train CRPS: {crps_train:.4f} | Test CRPS: {crps_test:.4f}
""",
fontsize=18,
fontweight="bold",
);
Observe that, as mentioned in the Pyro example, the model performs quite well except for the test data around Christmas. We could address this by enlarging the training data and by adding these special holidays as features (either dummies or Gaussian bump functions, see Seasonal Bump Functions).
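For intuition, a Gaussian bump feature for Christmas could look roughly like the sketch below; the 24-hour width and the helper name gaussian_bump are hypothetical choices, and the feature would still need to be wired into the model (for example, as an extra column of the covariates tensor with its own coefficient).
# Hypothetical sketch of a Gaussian bump holiday feature (not wired into the model above).
def gaussian_bump(time_index: Array, center: int, width: float) -> Array:
    """Smooth bump that peaks at `center` and decays over roughly `width` hours."""
    return jnp.exp(-0.5 * ((time_index - center) / width) ** 2)

christmas_bump = gaussian_bump(time, center=christmas_index, width=24.0)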