These are some notes on how to implement an ARMA(1, 1) model using NumPyro
for time series forecasting. The ARMA(1, 1) model is given by
\[y_t = \mu + \phi y_{t-1} + \theta \varepsilon_{t-1} + \varepsilon_t\]
where \(y_t\) is the time series, \(\mu\) is the mean, \(\phi\) is the autoregressive parameter, \(\theta\) is the moving average parameter, and \(\varepsilon_t\) is a white noise process with mean zero and variance \(\sigma^2\).
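Recall that for \(|\phi| < 1\) the process is stationary; taking expectations on both sides of the model equation gives the unconditional mean
\[\mathbb{E}[y_t] = \mu + \phi \, \mathbb{E}[y_{t-1}] \quad \Longrightarrow \quad \mathbb{E}[y_t] = \frac{\mu}{1 - \phi}.\]
Similarly, \(|\theta| < 1\) ensures the process is invertible. These two constraints motivate the uniform priors on \((-1, 1)\) that we place on \(\phi\) and \(\theta\) below.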
We work out an end-to-end example: from data generation to model fitting and forecasting. We put this example together following the Example: AR(2) process from the documentation and the discussion in the forum thread Lax.scan to implement ARMA(1, 1). In addition, we reuse some code from the previous notes on exponential smoothing.
Prepare Notebook
from collections.abc import Callable
import arviz as az
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import lax, random
from jaxlib.xla_extension import ArrayImpl
from numpyro.contrib.control_flow import scan
from numpyro.infer import MCMC, NUTS, Predictive
from pydantic import BaseModel, Field
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from statsmodels.tsa.arima.model import ARIMA
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"
numpyro.set_host_device_count(n=4)
rng_key = random.PRNGKey(seed=42)
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
Generate Data
Whenever getting started with time series models (especially those with autoregressive and moving average components), it is a good idea to generate some synthetic data to understand how the model works. In addition, we can validate our implementation by trying to recover the known parameters. We generate a time series with \(100\) observations using the ARMA(1, 1) model. We set the parameters \(\mu = 0.0\), \(\phi = 0.4\), \(\theta = 0.7\), and \(\sigma = 0.5\).
n_samples = 100 + 1
mu = 0.0
phi = 0.4
theta = 0.7
noise_scale = 0.5
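As a reference for the plots below, these parameter values imply a stationary mean of \(\mu / (1 - \phi) = 0\) and a stationary standard deviation of
\[\sigma \sqrt{\frac{1 + 2\phi\theta + \theta^2}{1 - \phi^2}} \approx 0.78.\]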
For-Loop Approach
To start simple, we can generate the time series using a for-loop. This is a very transparent way to understand the process.
def generate_arma_1_1_data_for_loop(
rng_key: ArrayImpl, n_samples: int, phi: float, theta: float, noise_scale: float
) -> ArrayImpl:
# Generate noise
error = noise_scale * random.normal(rng_key, (n_samples,))
# Preallocate array
y = jnp.zeros(n_samples)
# Data generating process
for t in range(1, y.size):
ar_part = phi * y[t - 1]
ma_part = theta * error[t - 1]
y_t = ar_part + ma_part + error[t]
y = y.at[t].set(y_t) # noqa
# Return the time series without the first element.
return y[1:]
y_for_loop = generate_arma_1_1_data_for_loop(
rng_key=rng_key, n_samples=n_samples, phi=phi, theta=theta, noise_scale=noise_scale
)
Scan Approach
It is also instructive to generate the same time series using lax.scan. This is a more efficient way to generate the time series, and it is also the way to go when implementing the model in NumPyro. For an introduction to scan, you can take a look at the previous post Notes on Exponential Smoothing with NumPyro.
def generate_arma_1_1_data_scan(
rng_key: ArrayImpl, n_samples: int, phi: float, theta: float, noise_scale: float
) -> ArrayImpl:
# Generate noise
error = noise_scale * random.normal(rng_key, (n_samples,))
# Data generating process step
def arma_step(carry, noise):
y_prev, error_prev = carry
ar_part = phi * y_prev
ma_part = theta * error_prev
y_t = ar_part + ma_part + noise
return (y_t, noise), y_t
# Scan over the time series
init_carry = (0.0, error[0])
_, y = lax.scan(arma_step, init_carry, error[1:])
return y
y_scan = generate_arma_1_1_data_scan(
rng_key=rng_key, n_samples=n_samples, phi=phi, theta=theta, noise_scale=noise_scale
)
We can verify that we get the same results (as we have used the same seed for the random number generator).
jnp.allclose(y_for_loop, y_scan, atol=1e-7)
Array(True, dtype=bool)
We can now visualize the time series.
y = y_scan
t = jnp.arange(y.size)
fig, ax = plt.subplots()
ax.plot(t, y)
ax.set(xlabel="time", ylabel="y", title="Time Series Data")
We can plot the autocorrelation (ACF) and partial autocorrelation (PACF) functions to verify that the time series has the expected structure: for an ARMA(1, 1) process, both should tail off gradually rather than cut off after a fixed lag.
fig, ax = plt.subplots(
nrows=2, ncols=1, sharex=True, sharey=True, figsize=(12, 8), layout="constrained"
)
_ = plot_acf(y, lags=12, ax=ax[0])
_ = plot_pacf(y, lags=12, ax=ax[1])
Train-Test Split
As usual, let's do a train-test split.
n = y.size
prop_train = 0.9
n_train = round(prop_train * n)
y_train = y[:n_train]
t_train = t[:n_train]
y_test = y[n_train:]
t_test = t[n_train:]
fig, ax = plt.subplots()
ax.plot(t_train, y_train, color="C0", label="train")
ax.plot(t_test, y_test, color="C1", label="test")
ax.axvline(x=t_train[-1], c="black", linestyle="--")
ax.legend()
ax.set(xlabel="time", ylabel="y", title="Time Series Data Split")
ARMA Model: Statsmodels
Before implementing the model in NumPyro, we can use the statsmodels library to fit the ARMA(1, 1) model. This is a good way to verify our implementation.
model = ARIMA(np.asarray(y_train), order=(1, 0, 1))
result = model.fit()
result.summary()
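We can also print the point estimates directly. Note that the parameter names (const, ar.L1, ma.L1, sigma2) follow the statsmodels naming conventions:
# Point estimates; these should be close to the values mu = 0.0, phi = 0.4,
# theta = 0.7 and sigma^2 = 0.25 used to generate the data.
print(result.params)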
We see that the estimated parameters are close to the ones we used to generate the data.
We now use the fitted model to forecast on the test set.
forecast_94_df = result.get_forecast(steps=y_test.size).summary_frame(alpha=0.06)
forecast_50_df = result.get_forecast(steps=y_test.size).summary_frame(alpha=0.5)
fig, ax = plt.subplots()
ax.fill_between(
t_test,
forecast_94_df["mean_ci_lower"],
forecast_94_df["mean_ci_upper"],
color="C2",
alpha=0.3,
label=r"$94\%$ CI (mean) [statsmodels]",
)
ax.fill_between(
t_test,
forecast_50_df["mean_ci_lower"],
forecast_50_df["mean_ci_upper"],
color="C2",
alpha=0.5,
label=r"$50\%$ CI (mean) [statsmodels]",
)
ax.plot(t_test, forecast_94_df["mean"], color="C2", label="mean forecast [statsmodels]")
ax.plot(t_train, y_train, color="C0", label="train")
ax.plot(t_test, y_test, color="C1", label="test")
ax.axvline(x=t_train[-1], c="black", linestyle="--")
ax.legend()
ax.set(xlabel="time", ylabel="y", title="ARMA(1, 1) Model Forecast (Statsmodels)")
NumPyro ARMA(1, 1) Model
Now we are ready to implement the ARMA(1, 1) model in NumPyro.
Model Specification
As mentioned in the introduction, we blend the strategies presented in the Example: AR(2) process from the documentation and the discussion in the forum thread Lax.scan to implement ARMA(1, 1). Overall, the strategy is similar to the exponential smoothing case: set priors, write the transition function, and pass it through scan. Still, there are two major differences (at least in my implementation):
1. We do not condition on the data y but on the errors obtained after passing the data through scan. I learned this the hard way, as the moving average parameter \(\theta\) was unchanged from the prior when conditioning on y. This is simply because in the scan step the term y[t] - pred would vanish 🤦 (see the note right after this list).
2. Inside the model I had to write a custom prediction loop (via scan). The reason is that we want to recursively pass previous predictions into the autoregressive term. In addition, we need to set the errors to zero at prediction time (except for the last one, for which we use the latest observed error). For a detailed discussion on the prediction step, please see the online chapter Ch 9.8: ARIMA Models; Forecasting.
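Regarding the first point, note that conditioning on the errors is equivalent to conditioning on the data: since \(\varepsilon_t = y_t - \hat{y}_t\), where \(\hat{y}_t = \mu + \phi y_{t-1} + \theta \varepsilon_{t-1}\) is a deterministic function of the parameters and the observed past, the map \(y_t \mapsto \varepsilon_t\) has unit Jacobian, so
\[\prod_{t} N(\varepsilon_t \mid 0, \sigma^2) = \prod_{t} N(y_t \mid \hat{y}_t, \sigma^2),\]
which is precisely the conditional likelihood of the data.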
def arma_1_1(y: ArrayImpl, future: int = 0) -> None:
    t_max = y.size
    # Priors (the uniform supports enforce stationarity and invertibility)
    mu = numpyro.sample("mu", dist.Normal(loc=0, scale=1))
    phi = numpyro.sample("phi", dist.Uniform(low=-1, high=1))
    theta = numpyro.sample("theta", dist.Uniform(low=-1, high=1))
    sigma = numpyro.sample("sigma", dist.HalfNormal(scale=1))

    # Transition step: compute the one-step-ahead prediction and its error
    def transition_fn(carry, t):
        y_prev, error_prev = carry
        ar_part = phi * y_prev
        ma_part = theta * error_prev
        pred = mu + ar_part + ma_part
        error = y[t] - pred
        return (y[t], error), error

    # First error term, initializing the recursion at y_{-1} = mu and error_{-1} = 0
    error_0 = y[0] - (mu + phi * mu)
    _, errors = scan(transition_fn, (y[0], error_0), jnp.arange(1, t_max))
    errors = jnp.concat([error_0[None], errors])
    # Condition on the errors instead of on y (see the discussion above)
    numpyro.sample("errors", dist.Normal(loc=0, scale=sigma), obs=errors)

    if future > 0:
        # Prediction step: feed predictions back into the autoregressive term
        # and set future errors to zero.
        def prediction_fn(carry, _):
            y_prev, error_prev = carry
            ar_part = phi * y_prev
            ma_part = theta * error_prev
            pred = mu + ar_part + ma_part
            error = 0.0
            return (pred, error), pred

        # The initial carry uses the last observation and the latest observed error
        _, y_forecast = scan(prediction_fn, (y[-1], errors[-1]), jnp.arange(future))
        numpyro.sample("y_forecast", dist.Normal(loc=y_forecast, scale=sigma))
Inference
We are now ready to fit the model using NumPyro. We use the NUTS sampler.
class InferenceParams(BaseModel):
num_warmup: int = Field(2_000, ge=1)
num_samples: int = Field(2_000, ge=1)
num_chains: int = Field(4, ge=1)
def run_inference(
rng_key: ArrayImpl,
model: Callable,
args: InferenceParams,
*model_args,
**nuts_kwargs,
) -> MCMC:
sampler = NUTS(model, **nuts_kwargs)
mcmc = MCMC(
sampler=sampler,
num_warmup=args.num_warmup,
num_samples=args.num_samples,
num_chains=args.num_chains,
)
mcmc.run(rng_key, *model_args)
return mcmc
inference_params = InferenceParams()
rng_key, rng_subkey = random.split(key=rng_key)
mcmc = run_inference(rng_subkey, arma_1_1, inference_params, y_train)
Let’s look into the diagnostics and estimated parameters.
idata = az.from_numpyro(posterior=mcmc)
az.summary(data=idata)
|       | mean   | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-------|--------|-------|--------|---------|-----------|---------|----------|----------|-------|
| mu    | -0.089 | 0.089 | -0.266 | 0.071   | 0.001     | 0.001   | 6845.0   | 5324.0   | 1.0   |
| phi   | 0.362  | 0.115 | 0.161  | 0.590   | 0.001     | 0.001   | 6520.0   | 5058.0   | 1.0   |
| sigma | 0.498  | 0.038 | 0.429  | 0.569   | 0.000     | 0.000   | 7348.0   | 5868.0   | 1.0   |
| theta | 0.678  | 0.083 | 0.527  | 0.831   | 0.001     | 0.001   | 6446.0   | 6010.0   | 1.0   |
print(f"""Divergences: {idata["sample_stats"]["diverging"].sum().item()}""")
Divergences: 0
axes = az.plot_trace(
data=idata,
compact=True,
lines=[
("mu", {}, mu),
("phi", {}, phi),
("theta", {}, theta),
("sigma", {}, noise_scale),
],
backend_kwargs={"figsize": (12, 7), "layout": "constrained"},
)
plt.gcf().suptitle("Trace", fontsize=16)
The model diagnostics look fine. Also, note that we have successfully recovered the parameters used to generate the data 🚀!
Forecasting
We are now ready to forecast on the test set.
def forecast(
rng_key: ArrayImpl, model: Callable, samples: dict[str, ArrayImpl], *model_args
) -> dict[str, ArrayImpl]:
predictive = Predictive(
model=model,
posterior_samples=samples,
return_sites=["y_forecast", "errors"],
)
return predictive(rng_key, *model_args)
rng_key, rng_subkey = random.split(key=rng_key)
forecast_samples = forecast(rng_subkey, arma_1_1, mcmc.get_samples(), y_train, y_test.size)
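As a quick sanity check, the samples are batched over the \(4 \times 2{,}000\) posterior draws:
# Expect (8_000, y_test.size) for the forecast samples.
print(forecast_samples["y_forecast"].shape)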
posterior_predictive = az.from_numpyro(
    posterior_predictive=forecast_samples,
    coords={"t_test": t_test, "t": t_train},
    dims={"y_forecast": ["t_test"], "errors": ["t"]},
)
We can now visualize the forecast at various credible intervals.
fig, ax = plt.subplots()
az.plot_hdi(
x=t_test,
y=posterior_predictive["posterior_predictive"]["y_forecast"],
hdi_prob=0.94,
color="C3",
smooth=False,
fill_kwargs={"alpha": 0.2, "label": r"$94\%$ HDI"},
ax=ax,
)
az.plot_hdi(
x=t_test,
y=posterior_predictive["posterior_predictive"]["y_forecast"],
hdi_prob=0.50,
color="C3",
smooth=False,
fill_kwargs={"alpha": 0.5, "label": r"$50\%$ HDI"},
ax=ax,
)
ax.plot(
t_test,
posterior_predictive["posterior_predictive"]["y_forecast"].mean(
dim=("chain", "draw")
),
color="C3",
label="mean forecast",
)
ax.plot(t_train, y_train, color="C0", label="train")
ax.plot(t_test, y_test, color="C1", label="test")
ax.axvline(x=t_train[-1], c="black", linestyle="--")
ax.legend()
ax.set(xlabel="time", ylabel="y", title="ARMA(1, 1) Model Forecast")
The results look good and are very comparable to the ones obtained using the statsmodels library:
fig, ax = plt.subplots()
az.plot_hdi(
x=t_test,
y=posterior_predictive["posterior_predictive"]["y_forecast"],
hdi_prob=0.94,
color="C3",
smooth=False,
fill_kwargs={"alpha": 0.3, "label": r"$94\%$ HDI"},
ax=ax,
)
az.plot_hdi(
x=t_test,
y=posterior_predictive["posterior_predictive"]["y_forecast"],
hdi_prob=0.50,
color="C3",
smooth=False,
fill_kwargs={"alpha": 0.5, "label": r"$50\%$ HDI"},
ax=ax,
)
ax.plot(
t_test,
posterior_predictive["posterior_predictive"]["y_forecast"].mean(
dim=("chain", "draw")
),
color="C3",
label="mean forecast",
)
ax.plot(t_test, forecast_94_df["mean"], color="C2", label="mean forecast [statsmodels]")
ax.plot(t_train, y_train, color="C0", label="train")
ax.plot(t_test, y_test, color="C1", label="test")
ax.axvline(x=t_train[-1], c="black", linestyle="--")
ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.1), ncol=3)
ax.set(xlabel="time", ylabel="y", title="ARMA(1, 1) Model Forecast - Model Comparison")
Using these learnings, it should not be hard to implement a general ARIMA(p, d, q) model; one just has to be careful to vectorize the model so that it stays efficient, as sketched below.
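As a starting point, here is a sketch (not from the references above; the names and the lag-vector layout are my own choices) of how the scan transition function could generalize to an ARMA(p, q) model, where phi and theta are now coefficient vectors of lengths p and q:
def arma_p_q_step(carry, t):
    # carry holds the p most recent observations and the q most recent errors,
    # most recent first; mu, phi, theta and y are closed over as in arma_1_1.
    y_lags, error_lags = carry
    pred = mu + jnp.dot(phi, y_lags) + jnp.dot(theta, error_lags)
    error = y[t] - pred
    # Shift the lag vectors and insert the newest values at position 0.
    y_lags = jnp.roll(y_lags, shift=1).at[0].set(y[t])
    error_lags = jnp.roll(error_lags, shift=1).at[0].set(error)
    return (y_lags, error_lags), error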