This notebook contains my personal notes on NumPyro’s implementation of the classic exponential smoothing forecasting method. The strategy is to go through the nitty-gritty details of the code presented in the documentation example “Example: Holt-Winters Exponential Smoothing”. In particular, I want to understand the auto-regressive components written with the scan function, which always confuses me 😅. After reproducing the example from the documentation, we go a step further and extend the algorithm to include a damped trend.
These notes do not aim to give a complete introduction to exponential smoothing. Instead, we focus on the implementation using NumPyro. For a detailed and comprehensive introduction to the subject (and forecasting topics in general), please refer to the great online book “Forecasting: Principles and Practice” by Rob J Hyndman and George Athanasopoulos. In particular, see Chapter 8 for an introduction to exponential smoothing.
Prepare Notebook
from collections.abc import Callable
import arviz as az
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
import preliz as pz
from jax import random
from jax.lax import fori_loop
from jaxlib.xla_extension import ArrayImpl
from numpyro.contrib.control_flow import scan
from numpyro.infer import MCMC, NUTS, Predictive
from pydantic import BaseModel, Field
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"
numpyro.set_host_device_count(n=4)
rng_key = random.PRNGKey(seed=42)
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
The Scan Function
We start these notes by revisiting the scan function, as it is a key ingredient in the implementation of the exponential smoothing algorithm. The scan function is a generalization of the for loop: it performs a computation in which each step depends on the previous step’s output. NumPyro’s scan operator wraps jax.lax.scan to allow NumPyro primitives like sample and deterministic to be used in a loop. We will use it to implement the auto-regressive components of the exponential smoothing algorithm. To gain some intuition about this function, the jax.lax.scan documentation provides a rough pure Python implementation:
def scan(f, init, xs, length=None):
    """Pure Python implementation of scan.

    Parameters
    ----------
    f : a Python function to be scanned.
    init : an initial loop carry value.
    xs : the value over which to scan along the leading axis.
    length : optional integer specifying the number of loop iterations,
        which must agree with the sizes of leading axes of the arrays in
        xs (but can be used to perform scans where no input xs are needed).
    """
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)
Whenever I read this I get a bit confused. So let’s try to understand the scan function by implementing a simple example.
Example: Sum of Powers
We will implement a simple sum of powers. We want to calculate the sum of the first h powers of a number phi:
\[ \phi \longmapsto \phi^1 + \phi^2 + \ldots + \phi^h \]
We can do this using a for loop. However, we want to use the scan function to get a better understanding of how it works.
First, we implement the sum of powers using a for loop.
def sum_of_powers_for_loop(phi: float, h: int) -> float:
    return sum(phi**i for i in range(1, h + 1))
assert sum_of_powers_for_loop(2, 0) == 0
assert sum_of_powers_for_loop(2, 1) == 2
assert sum_of_powers_for_loop(2, 2) == 2 + 2**2
assert sum_of_powers_for_loop(2, 3) == 2 + 2**2 + 2**3
Now, let’s look at the implementation using the scan function.
def sum_of_powers_scan(phi, h):
    def transition_fn(carry, phi):
        power_sum, power = carry
        power = power * phi
        power_sum = power_sum + power
        return (power_sum, power), power_sum

    (power_sum, _), _ = scan(f=transition_fn, init=(0, 1), xs=jnp.ones(h) * phi)
    return power_sum
We can verify that the two implementations give the same result.
assert sum_of_powers_scan(2, 0) == sum_of_powers_for_loop(2, 0)
assert sum_of_powers_scan(2, 1) == sum_of_powers_for_loop(2, 1)
assert sum_of_powers_scan(2, 2) == sum_of_powers_for_loop(2, 2)
assert sum_of_powers_scan(2, 3) == sum_of_powers_for_loop(2, 3)
assert sum_of_powers_scan(2, 10) == sum_of_powers_for_loop(2, 10)
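To build more intuition, here is a minimal sketch (the function name cumulative_power_sums is mine) that uses jax.lax.scan directly, outside of any NumPyro model, so we can inspect both return values: the final carry and the stacked per-step outputs.
from jax import lax


def cumulative_power_sums(phi: float, h: int):
    def transition_fn(carry, x):
        power_sum, power = carry
        power = power * x
        power_sum = power_sum + power
        # The second return value is collected at every iteration.
        return (power_sum, power), power_sum

    # Float inits keep the carry dtype stable across iterations.
    return lax.scan(transition_fn, (0.0, 1.0), jnp.ones(h) * phi)


final_carry, ys = cumulative_power_sums(2.0, 4)
# final_carry == (30.0, 16.0): the last (power_sum, power) pair.
# ys == [2.0, 6.0, 14.0, 30.0]: the running sums, stacked along the leading axis.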
There is another handy function to write efficient for loops in JAX: the fori_loop function. From the documentation we also get a pure Python implementation:
def fori_loop(lower, upper, body_fun, init_val):
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val
In this case we can easily implement the sum of powers using the fori_loop function.
def sum_of_powers_fori_loop(phi, h):
    def body_fn(i, power_sum):
        return power_sum + phi**i

    return fori_loop(lower=1, upper=h + 1, body_fun=body_fn, init_val=0)
assert sum_of_powers_fori_loop(2, 0) == sum_of_powers_for_loop(2, 0)
assert sum_of_powers_fori_loop(2, 1) == sum_of_powers_for_loop(2, 1)
assert sum_of_powers_fori_loop(2, 2) == sum_of_powers_for_loop(2, 2)
assert sum_of_powers_fori_loop(2, 3) == sum_of_powers_for_loop(2, 3)
assert sum_of_powers_fori_loop(2, 10) == sum_of_powers_for_loop(2, 10)
After this brief introduction to the scan function, we are ready to dive into the implementation of the exponential smoothing algorithm using NumPyro.
Generate Synthetic Data
We start by generating some synthetic data. We will use a very similar data generation process as in the example from the documentation. We generate a time series with a trend and a seasonal component.
n_seasons = 15
t = jnp.linspace(start=0, stop=n_seasons + 1, num=(n_seasons + 1) * n_seasons)
y = jnp.cos(2 * jnp.pi * t) + jnp.log(t + 1) + 0.2 * random.normal(rng_key, t.shape)
fig, ax = plt.subplots()
ax.plot(t, y)
ax.set(xlabel="time", ylabel="y", title="Time Series Data")
Train - Test Split
We now split the data into a training and a test set. We will use the training set to fit the model and the test set to evaluate the model’s performance.
n = y.size
prop_train = 0.8
n_train = round(prop_train * n)
y_train = y[:n_train]
t_train = t[:n_train]
y_test = y[n_train:]
t_test = t[n_train:]
fig, ax = plt.subplots()
ax.plot(t_train, y_train, color="C0", label="train")
ax.plot(t_test, y_test, color="C1", label="test")
ax.axvline(x=t_train[-1], c="black", linestyle="--")
ax.legend()
ax.set(xlabel="time", ylabel="y", title="Time Series Data Split")
Level Model
We do not look at the complete model from the start. Instead, we progressively study and add components. We start with the most fundamental component of the model: the level model. The level model is a simple model that predicts the next value in the time series based on the previous value. The model is defined by the following equations:
\[\begin{align*} \hat{y}_{t+h|t} = & \: l_t \\ l_t = & \: \alpha y_t + (1 - \alpha)l_{t-1} \end{align*}\]
Here:
- \(y_t\) is the observed value at time \(t\).
- \(\hat{y}_{t+h|t}\) is the forecast of the value at time \(t+h\) given the information up to time \(t\).
- \(l_t\) is the level at time \(t\).
- \(\alpha\) is the smoothing parameter. It is a value between 0 and 1. A value of 1 means that the forecast is based only on the observed value at time \(t\). A value of 0 means that the forecast is based only on the previous level.
Note that the level equation is a simple weighted average of the observed value and the previous level (for more details see 8.1 Simple exponential smoothing). Moreover, note that this recursion is perfectly suited for the scan function, as the sketch below illustrates.
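Before jumping into the full NumPyro model, here is a minimal sketch (my own, with a fixed smoothing value alpha and no noise term) of the level recursion written with jax.lax.scan: the carry holds the current level and the per-step output is the one-step-ahead prediction (the previous level).
from jax import lax


def level_recursion(y, alpha, level_init):
    def transition_fn(previous_level, y_t):
        level = alpha * y_t + (1 - alpha) * previous_level
        # The one-step-ahead prediction is the previous level.
        return level, previous_level

    return lax.scan(transition_fn, level_init, y)


last_level, one_step_preds = level_recursion(y_train, alpha=0.5, level_init=y_train[0])
The NumPyro model below has exactly this structure, plus priors on the smoothing parameter and the initial level, and a Normal likelihood around the predictions.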
Model Specification
We start by specifying the level model in NumPyro. We use the same priors as in the example from the documentation. Note that we place a uniform Beta(1, 1) prior on the smoothing parameter \(\alpha\):
fig, ax = plt.subplots()
pz.Beta(alpha=1, beta=1).plot_pdf(ax=ax)
ax.set(title="Beta(1, 1) Prior")
def level_model(y: ArrayImpl, future: int = 0) -> None:
    # Get time series length
    t_max = y.shape[0]

    # --- Priors ---
    ## Level
    level_smoothing = numpyro.sample(
        "level_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    level_init = numpyro.sample("level_init", dist.Normal(loc=0, scale=1))

    ## Noise
    noise = numpyro.sample("noise", dist.HalfNormal(scale=1))

    # --- Transition Function ---
    def transition_fn(carry, t):
        previous_level = carry

        level = jnp.where(
            t < t_max,
            level_smoothing * y[t] + (1 - level_smoothing) * previous_level,
            previous_level,
        )

        mu = previous_level
        pred = numpyro.sample("pred", dist.Normal(loc=mu, scale=noise))

        return level, pred

    # --- Run Scan ---
    with numpyro.handlers.condition(data={"pred": y}):
        _, preds = scan(
            transition_fn,
            level_init,
            jnp.arange(t_max + future),
        )

    # --- Forecast ---
    if future > 0:
        numpyro.deterministic("y_forecast", preds[-future:])
Observe that we are using the condition effect handler to condition the model on the observed data. As explained in the example from the documentation “Time Series Forecasting”:

The reason is we also want to use this model for forecasting. In forecasting, future values of y are non-observable, so obs=y[t] does not make sense when t >= len(y) (caution: index out-of-bound errors do not get raised in JAX, e.g. jnp.arange(3)[10] == 2). Using condition, when the length of scan is larger than the length of the conditioned/observed site, unobserved values will be sampled from the distribution of that site.
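To convince ourselves of this behavior, here is a toy sketch (the model and all names are mine, not from the example): we condition the scanned "pred" site on three observations and run the scan for five steps, so the last two values are sampled from the transition distribution instead of being clamped to the data.
def toy_model(n_steps: int) -> None:
    def transition_fn(carry, t):
        pred = numpyro.sample("pred", dist.Normal(loc=carry, scale=1.0))
        return pred, pred

    scan(transition_fn, jnp.asarray(0.0), jnp.arange(n_steps))


conditioned_model = numpyro.handlers.condition(
    toy_model, data={"pred": jnp.array([1.0, 2.0, 3.0])}
)
seeded_model = numpyro.handlers.seed(conditioned_model, rng_seed=0)
exec_trace = numpyro.handlers.trace(seeded_model).get_trace(5)
# The first three entries of exec_trace["pred"]["value"] equal the observed
# 1.0, 2.0, 3.0; the remaining two are fresh samples.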
Inference
We now proceed to use MCMC to fit the level model to the training data. We define some helper functions, as we will apply them to many models.
class InferenceParams(BaseModel):
    num_warmup: int = Field(2_000, ge=1)
    num_samples: int = Field(2_000, ge=1)
    num_chains: int = Field(4, ge=1)


def run_inference(
    rng_key: ArrayImpl,
    model: Callable,
    args: InferenceParams,
    *model_args,
    **nuts_kwargs,
) -> MCMC:
    sampler = NUTS(model, **nuts_kwargs)
    mcmc = MCMC(
        sampler=sampler,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
    )
    mcmc.run(rng_key, *model_args)
    return mcmc
Let’s fit the model:
inference_params = InferenceParams()
rng_key, rng_subkey = random.split(key=rng_key)
level_mcmc = run_inference(rng_subkey, level_model, inference_params, y_train)
The diagnostics look good:
level_idata = az.from_numpyro(posterior=level_mcmc)
az.summary(data=level_idata)
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| level_init | 1.013 | 0.396 | 0.275 | 1.778 | 0.005 | 0.004 | 5411.0 | 4042.0 | 1.0 |
| level_smoothing | 0.974 | 0.023 | 0.931 | 1.000 | 0.000 | 0.000 | 4779.0 | 3592.0 | 1.0 |
| noise | 0.430 | 0.022 | 0.392 | 0.472 | 0.000 | 0.000 | 6447.0 | 5120.0 | 1.0 |
print(f"""Divergences: {level_idata["sample_stats"]["diverging"].sum().item()}""")
Divergences: 0
axes = az.plot_trace(
data=level_idata,
compact=True,
kind="rank_bars",
backend_kwargs={"figsize": (12, 7), "layout": "constrained"},
)
plt.gcf().suptitle("Level Model Trace", fontsize=16)
Forecast
Now we can use the fitted model to forecast the test set. As above, we define a helper function for this purpose:
def forecast(
    rng_key: ArrayImpl, model: Callable, samples: dict[str, ArrayImpl], *model_args
) -> dict[str, ArrayImpl]:
    predictive = Predictive(
        model=model,
        posterior_samples=samples,
        return_sites=["y_forecast"],
    )
    return predictive(rng_key, *model_args)
rng_key, rng_subkey = random.split(key=rng_key)
level_forecast = forecast(
rng_subkey, level_model, level_mcmc.get_samples(), y_train, y_test.size
)
level_posterior_predictive = az.from_numpyro(
posterior_predictive=level_forecast,
coords={"t": t_test},
dims={"y_forecast": ["t"]},
)
We can now visualize the forecast:
fig, ax = plt.subplots()
az.plot_hdi(
x=t_test,
y=level_posterior_predictive["posterior_predictive"]["y_forecast"],
hdi_prob=0.94,
color="C2",
fill_kwargs={"alpha": 0.2, "label": r"$94\%$ HDI"},
ax=ax,
)
az.plot_hdi(
x=t_test,
y=level_posterior_predictive["posterior_predictive"]["y_forecast"],
hdi_prob=0.50,
color="C2",
fill_kwargs={"alpha": 0.5, "label": r"$50\%$ HDI"},
ax=ax,
)
ax.plot(
t_test,
level_posterior_predictive["posterior_predictive"]["y_forecast"].mean(
dim=("chain", "draw")
),
color="C2",
label="mean forecast",
)
ax.plot(t_train, y_train, color="C0", label="train")
ax.plot(t_test, y_test, color="C1", label="test")
ax.axvline(x=t_train[-1], c="black", linestyle="--")
ax.legend()
ax.set(xlabel="time", ylabel="y", title="Level Model Forecast")
As expected, the forecast is flat: the level model’s h-step-ahead forecast is just the last fitted level \(l_t\), regardless of \(h\).
Level + Trend Model
Next, we add a trend component to the model. The trend model is defined by the following equations:
\[\begin{align*} \hat{y}_{t+h|t} = & \: l_t + hb_t \\ l_t = & \: \alpha y_t + (1 - \alpha)(l_{t-1} + b_{t-1}) \\ b_t = & \: \beta^*(l_t - l_{t - 1}) + (1-\beta^*)b_{t - 1} \end{align*}\]
Here \(b_t\) denotes the trend at time \(t\) and \(\beta^*\) is a smoothing parameter. The trend equation is a weighted average of the change in level, \(l_t - l_{t-1}\), and the previous trend \(b_{t-1}\).
Remark: The superscript \(*\) in \(\beta^*\) is just a result of the notation used in the state space representation of the model. See here.
Model Specification
Note that given the level model above, we can easily extend it to include the trend component:
def level_trend_model(y: ArrayImpl, future: int = 0) -> None:
    # Get time series length
    t_max = y.shape[0]

    # --- Priors ---
    ## Level
    level_smoothing = numpyro.sample(
        "level_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    level_init = numpyro.sample("level_init", dist.Normal(loc=0, scale=1))

    ## Trend
    trend_smoothing = numpyro.sample(
        "trend_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    trend_init = numpyro.sample("trend_init", dist.Normal(loc=0, scale=1))

    ## Noise
    noise = numpyro.sample("noise", dist.HalfNormal(scale=1))

    # --- Transition Function ---
    def transition_fn(carry, t):
        previous_level, previous_trend = carry

        level = jnp.where(
            t < t_max,
            level_smoothing * y[t]
            + (1 - level_smoothing) * (previous_level + previous_trend),
            previous_level,
        )

        trend = jnp.where(
            t < t_max,
            trend_smoothing * (level - previous_level)
            + (1 - trend_smoothing) * previous_trend,
            previous_trend,
        )

        step = jnp.where(t < t_max, 1, t - t_max + 1)
        mu = previous_level + step * previous_trend
        pred = numpyro.sample("pred", dist.Normal(loc=mu, scale=noise))

        return (level, trend), pred

    # --- Run Scan ---
    with numpyro.handlers.condition(data={"pred": y}):
        _, preds = scan(
            transition_fn,
            (level_init, trend_init),
            jnp.arange(t_max + future),
        )

    # --- Forecast ---
    if future > 0:
        numpyro.deterministic("y_forecast", preds[-future:])
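One detail worth spelling out is the step variable: during training (t < t_max) the model makes one-step-ahead predictions, while for t >= t_max the carry (level and trend) is frozen and the trend gets multiplied by the forecast horizon \(h = t - t_{max} + 1\), matching \(\hat{y}_{t+h|t} = l_t + hb_t\). A quick plain-Python mirror (my own) of the indexing:
t_max = 5
for t in range(3, 8):
    step = 1 if t < t_max else t - t_max + 1
    print(t, step)  # (3, 1), (4, 1), (5, 1), (6, 2), (7, 3)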
Inference
We fit the model as before:
rng_key, rng_subkey = random.split(key=rng_key)
level_trend_mcmc = run_inference(
rng_subkey, level_trend_model, inference_params, y_train
)
level_trend_idata = az.from_numpyro(posterior=level_trend_mcmc)
az.summary(data=level_trend_idata)
mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
---|---|---|---|---|---|---|---|---|---|
level_init | 1.224 | 0.378 | 0.518 | 1.928 | 0.005 | 0.004 | 4965.0 | 4681.0 | 1.0 |
level_smoothing | 0.641 | 0.047 | 0.560 | 0.733 | 0.001 | 0.001 | 4481.0 | 3110.0 | 1.0 |
noise | 0.430 | 0.023 | 0.388 | 0.473 | 0.000 | 0.000 | 5229.0 | 5336.0 | 1.0 |
trend_init | 0.011 | 0.344 | -0.619 | 0.664 | 0.005 | 0.004 | 4819.0 | 4917.0 | 1.0 |
trend_smoothing | 0.912 | 0.077 | 0.771 | 1.000 | 0.001 | 0.001 | 3466.0 | 3557.0 | 1.0 |
print(f"""Divergences: {level_trend_idata["sample_stats"]["diverging"].sum().item()}""")
Divergences: 0
axes = az.plot_trace(
data=level_trend_idata,
compact=True,
kind="rank_bars",
backend_kwargs={"figsize": (12, 9), "layout": "constrained"},
)
plt.gcf().suptitle("Level + Trend Model Trace", fontsize=16)
Forecast
We generate the forecast:
rng_key, rng_subkey = random.split(key=rng_key)
level_trend_forecast = forecast(
rng_subkey, level_trend_model, level_trend_mcmc.get_samples(), y_train, y_test.size
)
level_trend_posterior_predictive = az.from_numpyro(
posterior_predictive=level_trend_forecast,
coords={"t": t_test},
dims={"y_forecast": ["t"]},
)
fig, ax = plt.subplots()
az.plot_hdi(
x=t_test,
y=level_trend_posterior_predictive["posterior_predictive"]["y_forecast"],
hdi_prob=0.94,
color="C2",
fill_kwargs={"alpha": 0.2, "label": r"$94\%$ HDI"},
ax=ax,
)
az.plot_hdi(
x=t_test,
y=level_trend_posterior_predictive["posterior_predictive"]["y_forecast"],
hdi_prob=0.50,
color="C2",
fill_kwargs={"alpha": 0.5, "label": r"$50\%$ HDI"},
ax=ax,
)
ax.plot(
t_test,
level_trend_posterior_predictive["posterior_predictive"]["y_forecast"].mean(
dim=("chain", "draw")
),
color="C2",
label="mean forecast",
)
ax.plot(t_train, y_train, color="C0", label="train")
ax.plot(t_test, y_test, color="C1", label="test")
ax.axvline(x=t_train[-1], c="black", linestyle="--")
ax.legend()
ax.set(xlabel="time", ylabel="y", title="Level + Trend Model Forecast")
Oops! This is literally just extrapolating with a straight line 😒. Still, this is often quite a good forecasting baseline for short-term horizons. We are clearly missing the seasonality component to complement the trend model. This is what we do next!
Level + Trend + Seasonality Model (Holt-Winters)
We have finally arrived at the (additive) level + trend + seasonality model, also known as the Holt-Winters model. We extend the trend model to include a seasonal component:
\[\begin{align*} \hat{y}_{t+h|t} = & \: l_t + hb_t + s_{t + h - m(k + 1)} \\ l_t = & \: \alpha(y_t - s_{t - m}) + (1 - \alpha)(l_{t - 1} + b_{t - 1}) \\ b_t = & \: \beta^*(l_t - l_{t - 1}) + (1 - \beta^*)b_{t - 1} \\ s_t = & \: \gamma(y_t - l_{t - 1} - b_{t - 1})+(1 - \gamma)s_{t - m} \end{align*}\]
We have added the seasonal component \(s_t\) with the corresponding smoothing parameter \(\gamma\). Here \(m\) denotes the number of seasons. The parameter \(k\) is the integer part of \((h - 1)/m\); this index simply picks the latest available seasonality estimate for the given time point. For example, when \((h - 1)/m\) is an integer, then
\[ t + h - m(k + 1) = t + h - m(h - 1)/m = t + h - (h - 1) = t + 1 \]
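A quick numeric check (my own) that this index always points to one of the last \(m\) estimated seasonal values:
m = 4
t = 100
for h in range(1, 10):
    k = (h - 1) // m
    print(h, t + h - m * (k + 1))
# h = 1, 2, 3, 4 -> 97, 98, 99, 100, and then the indices repeat,
# i.e. the index always lies in [t - m + 1, t].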
Remark: Similar to the note on the \(\beta^*\) notation, the parameter \(\gamma\) is often called an adjusted smoothing parameter, as it is of the form \(\gamma = \gamma^{*}(1 - \alpha)\), so that \(0 \leq \gamma^{*} \leq 1\) translates to \(0 \leq \gamma \leq 1 - \alpha\). This is just a result of the state space representation of the model. See here.
Model Specification
We now use the model from the example:
def holt_winters_model(y: ArrayImpl, n_seasons: int, future: int = 0) -> None:
    # Get time series length
    t_max = y.shape[0]

    # --- Priors ---
    ## Level
    level_smoothing = numpyro.sample(
        "level_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    level_init = numpyro.sample("level_init", dist.Normal(loc=0, scale=1))

    ## Trend
    trend_smoothing = numpyro.sample(
        "trend_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    trend_init = numpyro.sample("trend_init", dist.Normal(loc=0, scale=1))

    ## Seasonality
    seasonality_smoothing = numpyro.sample(
        "seasonality_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    adj_seasonality_smoothing = seasonality_smoothing * (1 - level_smoothing)

    with numpyro.plate("n_seasons", n_seasons):
        seasonality_init = numpyro.sample(
            "seasonality_init", dist.Normal(loc=0, scale=1)
        )

    ## Noise
    noise = numpyro.sample("noise", dist.HalfNormal(scale=1))

    # --- Transition Function ---
    def transition_fn(carry, t):
        previous_level, previous_trend, previous_seasonality = carry

        level = jnp.where(
            t < t_max,
            level_smoothing * (y[t] - previous_seasonality[0])
            + (1 - level_smoothing) * (previous_level + previous_trend),
            previous_level,
        )

        trend = jnp.where(
            t < t_max,
            trend_smoothing * (level - previous_level)
            + (1 - trend_smoothing) * previous_trend,
            previous_trend,
        )

        new_season = jnp.where(
            t < t_max,
            adj_seasonality_smoothing * (y[t] - (previous_level + previous_trend))
            + (1 - adj_seasonality_smoothing) * previous_seasonality[0],
            previous_seasonality[0],
        )

        step = jnp.where(t < t_max, 1, t - t_max + 1)
        mu = previous_level + step * previous_trend + previous_seasonality[0]
        pred = numpyro.sample("pred", dist.Normal(loc=mu, scale=noise))

        seasonality = jnp.concatenate(
            [previous_seasonality[1:], new_season[None]], axis=0
        )

        return (level, trend, seasonality), pred

    # --- Run Scan ---
    with numpyro.handlers.condition(data={"pred": y}):
        _, preds = scan(
            transition_fn,
            (level_init, trend_init, seasonality_init),
            jnp.arange(t_max + future),
        )

    # --- Forecast ---
    if future > 0:
        numpyro.deterministic("y_forecast", preds[-future:])
Remark [Seasonality Looping]: Please note that the new_season variable in the transition_fn simply computes the next seasonal step as in the mathematical formulation of the model. Observe how the previous season s_{t - m} is given by previous_seasonality[0]. The index zero looks weird, but it makes sense since we are rolling the seasonality carry below as

seasonality = jnp.concatenate(
    [previous_seasonality[1:], new_season[None]], axis=0
)

so that the first item of this array always corresponds to the desired previous seasonality in the iteration formula 💡.
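Here is a tiny numeric sketch (made-up values, m = 3) of this rolling buffer:
previous_seasonality = jnp.array([0.1, 0.2, 0.3])  # s_{t-3}, s_{t-2}, s_{t-1}
new_season = jnp.array(0.4)  # s_t
seasonality = jnp.concatenate([previous_seasonality[1:], new_season[None]], axis=0)
# seasonality == [0.2, 0.3, 0.4]; index 0 now holds s_{t-2} = s_{(t+1)-3},
# which is exactly the lagged season needed at step t + 1.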
Inference
Note that all of these models sample super fast!
rng_key, rng_subkey = random.split(key=rng_key)
holt_winters_mcmc = run_inference(
rng_subkey,
holt_winters_model,
inference_params,
y_train,
n_seasons,
)
holt_winters_idata = az.from_numpyro(posterior=holt_winters_mcmc)
az.summary(data=holt_winters_idata)
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| level_init | 0.162 | 0.278 | -0.359 | 0.681 | 0.008 | 0.006 | 1153.0 | 1703.0 | 1.0 |
| level_smoothing | 0.117 | 0.050 | 0.024 | 0.209 | 0.001 | 0.001 | 2256.0 | 2089.0 | 1.0 |
| noise | 0.256 | 0.014 | 0.231 | 0.282 | 0.000 | 0.000 | 3276.0 | 4369.0 | 1.0 |
| seasonality_init[0] | 0.965 | 0.296 | 0.433 | 1.545 | 0.008 | 0.006 | 1317.0 | 1721.0 | 1.0 |
| seasonality_init[1] | 0.908 | 0.300 | 0.372 | 1.496 | 0.008 | 0.006 | 1271.0 | 1967.0 | 1.0 |
| seasonality_init[2] | 0.542 | 0.297 | -0.023 | 1.091 | 0.008 | 0.006 | 1299.0 | 1953.0 | 1.0 |
| seasonality_init[3] | 0.258 | 0.296 | -0.282 | 0.818 | 0.008 | 0.006 | 1310.0 | 1900.0 | 1.0 |
| seasonality_init[4] | 0.021 | 0.298 | -0.544 | 0.562 | 0.008 | 0.006 | 1317.0 | 1961.0 | 1.0 |
| seasonality_init[5] | -0.661 | 0.297 | -1.246 | -0.126 | 0.008 | 0.006 | 1299.0 | 1983.0 | 1.0 |
| seasonality_init[6] | -0.836 | 0.299 | -1.375 | -0.258 | 0.008 | 0.006 | 1305.0 | 1993.0 | 1.0 |
| seasonality_init[7] | -0.989 | 0.295 | -1.554 | -0.440 | 0.008 | 0.006 | 1260.0 | 1995.0 | 1.0 |
| seasonality_init[8] | -0.773 | 0.293 | -1.330 | -0.220 | 0.008 | 0.006 | 1324.0 | 2162.0 | 1.0 |
| seasonality_init[9] | -0.824 | 0.295 | -1.370 | -0.266 | 0.008 | 0.006 | 1277.0 | 1950.0 | 1.0 |
| seasonality_init[10] | -0.542 | 0.299 | -1.112 | 0.005 | 0.009 | 0.006 | 1208.0 | 1912.0 | 1.0 |
| seasonality_init[11] | -0.091 | 0.299 | -0.642 | 0.475 | 0.008 | 0.006 | 1354.0 | 2243.0 | 1.0 |
| seasonality_init[12] | 0.387 | 0.296 | -0.173 | 0.939 | 0.008 | 0.006 | 1309.0 | 2039.0 | 1.0 |
| seasonality_init[13] | 0.663 | 0.297 | 0.111 | 1.225 | 0.008 | 0.006 | 1261.0 | 1827.0 | 1.0 |
| seasonality_init[14] | 0.910 | 0.298 | 0.349 | 1.459 | 0.008 | 0.006 | 1311.0 | 1942.0 | 1.0 |
| seasonality_smoothing | 0.299 | 0.082 | 0.145 | 0.454 | 0.002 | 0.001 | 2588.0 | 1663.0 | 1.0 |
| trend_init | 0.029 | 0.012 | 0.007 | 0.053 | 0.000 | 0.000 | 2908.0 | 3843.0 | 1.0 |
| trend_smoothing | 0.069 | 0.063 | 0.000 | 0.145 | 0.002 | 0.001 | 2153.0 | 2119.0 | 1.0 |
print(
f"""Divergences: {holt_winters_idata["sample_stats"]["diverging"].sum().item()}"""
)
Divergences: 0
axes = az.plot_trace(
data=holt_winters_idata,
compact=True,
kind="rank_bars",
backend_kwargs={"figsize": (12, 12), "layout": "constrained"},
)
plt.gcf().suptitle("Holt Winters Trace", fontsize=16)
The diagnostics and the trace plot look good!
Forecast
Let’s see if this model can forecast the test set better than the previous models:
rng_key, rng_subkey = random.split(key=rng_key)
holt_winters_forecast = forecast(
rng_subkey,
holt_winters_model,
holt_winters_mcmc.get_samples(),
y_train,
n_seasons,
y_test.size,
)
holt_winters_posterior_predictive = az.from_numpyro(
posterior_predictive=holt_winters_forecast,
coords={"t": t_test},
dims={"y_forecast": ["t"]},
)
fig, ax = plt.subplots()
az.plot_hdi(
x=t_test,
y=holt_winters_posterior_predictive["posterior_predictive"]["y_forecast"],
hdi_prob=0.94,
color="C2",
fill_kwargs={"alpha": 0.2, "label": r"$94\%$ HDI"},
ax=ax,
)
az.plot_hdi(
x=t_test,
y=holt_winters_posterior_predictive["posterior_predictive"]["y_forecast"],
hdi_prob=0.50,
color="C2",
fill_kwargs={"alpha": 0.5, "label": r"$50\%$ HDI"},
ax=ax,
)
ax.plot(
t_test,
holt_winters_posterior_predictive["posterior_predictive"]["y_forecast"].mean(
dim=("chain", "draw")
),
color="C2",
label="mean forecast",
)
ax.plot(t_train, y_train, color="C0", label="train")
ax.plot(t_test, y_test, color="C1", label="test")
ax.axvline(x=t_train[-1], c="black", linestyle="--")
ax.legend()
ax.set(xlabel="time", ylabel="y", title="Holt Winters Model Forecast")
Indeed! The forecast looks pretty good 🚀! The point forecast is slightly over-forecasting, but still within the credible range. Let’s see how we can make the trend forecast a bit more conservative 🤓.
Damped Holt-Winters
As a final model, we present an extension of the model above. We add a damping parameter \(\phi\) to the trend component. The damped Holt-Winters model is defined by the following equations:
\[\begin{align*} \hat{y}_{t+h|t} = & \: l_t + \phi_{h}b_t + s_{t + h - m(k + 1)} \\ l_t = & \: \alpha(y_t - s_{t - m}) + (1 - \alpha)(l_{t - 1} + \phi b_{t - 1}) \\ b_t = & \: \beta^*(l_t - l_{t - 1}) + (1 - \beta^*)\phi b_{t - 1} \\ s_t = & \: \gamma(y_t - l_{t - 1} - \phi b_{t - 1})+(1 - \gamma)s_{t - m} \end{align*}\]
where \(0 \leq \phi \leq 1\) and \(\phi_{h} := \phi + \phi^2 + \ldots + \phi^{h}\).
Remark: Now you see why we were interested in the power sum scan example at the beginning of these notes 😅.
Model Specification
Substituting \(b_{t - 1} \longmapsto \phi b_{t - 1}\) is straightforward. What needs a bit more attention is the \(\phi_{h}\) term. A plain Python for loop implementation is a no-go inside the model, and the scan approach does not work either, as the size of the xs argument is dynamic. Hence, the fori_loop function is the way to go. Moreover, in order to help the sampler, we impose more informative priors on the smoothing parameters.
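As a side note (my own observation, not from the example): since \(\phi_{h}\) is a geometric sum, it also admits the closed form \(\phi_{h} = \phi(1 - \phi^{h})/(1 - \phi)\) for \(\phi \neq 1\), which would avoid the fori_loop entirely (and hence the forward-mode differentiation requirement discussed in the Inference section below), at the cost of some numerical care as \(\phi \to 1\). A sketch:
def phi_step_closed_form(phi, step):
    # Geometric sum phi + phi^2 + ... + phi^step, valid for phi != 1.
    # With the Beta(2, 5) prior used below, phi = 1 has zero posterior mass.
    return phi * (1 - phi**step) / (1 - phi)
Below we stick with the fori_loop version, which stays close to the power sum example from the beginning of these notes.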
fig, ax = plt.subplots()
pz.Beta(alpha=1, beta=1).plot_pdf(ax=ax)
pz.Beta(alpha=2, beta=2).plot_pdf(ax=ax)
pz.Beta(alpha=2, beta=5).plot_pdf(ax=ax)
ax.set(title="Beta Priors")
def damped_holt_winters_model(y: ArrayImpl, n_seasons: int, future: int = 0) -> None:
    # Get time series length
    t_max = y.shape[0]

    # --- Priors ---
    ## Level
    level_smoothing = numpyro.sample("level_smoothing", dist.Beta(2, 2))
    level_init = numpyro.sample("level_init", dist.Normal(0, 1))

    ## Trend
    trend_smoothing = numpyro.sample("trend_smoothing", dist.Beta(2, 2))
    trend_init = numpyro.sample("trend_init", dist.Normal(0.5, 1))

    ## Seasonality
    seasonality_smoothing = numpyro.sample("seasonality_smoothing", dist.Beta(2, 2))
    adj_seasonality_smoothing = seasonality_smoothing * (1 - level_smoothing)

    ## Damping
    phi = numpyro.sample("phi", dist.Beta(2, 5))

    with numpyro.plate("n_seasons", n_seasons):
        seasonality_init = numpyro.sample("seasonality_init", dist.Normal(0, 1))

    ## Noise
    noise = numpyro.sample("noise", dist.HalfNormal(1))

    # --- Transition Function ---
    def transition_fn(carry, t):
        previous_level, previous_trend, previous_seasonality = carry

        level = jnp.where(
            t < t_max,
            level_smoothing * (y[t] - previous_seasonality[0])
            + (1 - level_smoothing) * (previous_level + phi * previous_trend),
            previous_level,
        )

        trend = jnp.where(
            t < t_max,
            trend_smoothing * (level - previous_level)
            + (1 - trend_smoothing) * phi * previous_trend,
            previous_trend,
        )

        new_season = jnp.where(
            t < t_max,
            adj_seasonality_smoothing * (y[t] - (previous_level + phi * previous_trend))
            + (1 - adj_seasonality_smoothing) * previous_seasonality[0],
            previous_seasonality[0],
        )

        step = jnp.where(t < t_max, 1, t - t_max + 1)
        phi_step = fori_loop(
            lower=1, upper=step + 1, body_fun=lambda i, val: val + phi**i, init_val=0
        )
        mu = previous_level + phi_step * previous_trend + previous_seasonality[0]
        pred = numpyro.sample("pred", dist.Normal(mu, noise))

        seasonality = jnp.concatenate(
            [previous_seasonality[1:], new_season[None]], axis=0
        )

        return (level, trend, seasonality), pred

    # --- Run Scan ---
    with numpyro.handlers.condition(data={"pred": y}):
        _, preds = scan(
            transition_fn,
            (level_init, trend_init, seasonality_init),
            jnp.arange(t_max + future),
        )

    # --- Forecast ---
    if future > 0:
        numpyro.deterministic("y_forecast", preds[-future:])
Inference
An important note is that the sampler needs the parameter forward_mode_differentiation=True, as we are using the fori_loop function, which is not compatible with reverse-mode differentiation. This is a bit of a bummer, since reverse-mode differentiation is often more efficient. However, the sampler still runs pretty fast.
rng_key, rng_subkey = random.split(key=rng_key)
damped_holt_winters_mcmc = run_inference(
rng_subkey,
damped_holt_winters_model,
inference_params,
y_train,
n_seasons,
target_accept_prob=0.95,
forward_mode_differentiation=True,
)
damped_holt_winters_idata = az.from_numpyro(posterior=damped_holt_winters_mcmc)
az.summary(data=damped_holt_winters_idata)
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| level_init | 0.141 | 0.375 | -0.575 | 0.836 | 0.011 | 0.008 | 1237.0 | 3017.0 | 1.01 |
| level_smoothing | 0.224 | 0.050 | 0.134 | 0.320 | 0.001 | 0.001 | 3590.0 | 4847.0 | 1.00 |
| noise | 0.262 | 0.014 | 0.235 | 0.288 | 0.000 | 0.000 | 4121.0 | 4769.0 | 1.00 |
| phi | 0.241 | 0.132 | 0.010 | 0.471 | 0.002 | 0.001 | 3654.0 | 3599.0 | 1.00 |
| seasonality_init[0] | 0.944 | 0.289 | 0.414 | 1.501 | 0.010 | 0.007 | 803.0 | 1878.0 | 1.01 |
| seasonality_init[1] | 0.893 | 0.291 | 0.385 | 1.473 | 0.010 | 0.007 | 807.0 | 2189.0 | 1.01 |
| seasonality_init[2] | 0.543 | 0.292 | 0.003 | 1.087 | 0.010 | 0.007 | 797.0 | 2162.0 | 1.01 |
| seasonality_init[3] | 0.275 | 0.289 | -0.254 | 0.829 | 0.010 | 0.007 | 803.0 | 1939.0 | 1.01 |
| seasonality_init[4] | 0.053 | 0.296 | -0.509 | 0.595 | 0.010 | 0.007 | 819.0 | 2140.0 | 1.01 |
| seasonality_init[5] | -0.602 | 0.286 | -1.117 | -0.050 | 0.010 | 0.007 | 792.0 | 2124.0 | 1.01 |
| seasonality_init[6] | -0.771 | 0.288 | -1.294 | -0.211 | 0.010 | 0.007 | 819.0 | 2157.0 | 1.01 |
| seasonality_init[7] | -0.928 | 0.287 | -1.448 | -0.365 | 0.010 | 0.007 | 774.0 | 1975.0 | 1.01 |
| seasonality_init[8] | -0.713 | 0.289 | -1.261 | -0.179 | 0.010 | 0.007 | 866.0 | 2105.0 | 1.01 |
| seasonality_init[9] | -0.759 | 0.289 | -1.302 | -0.215 | 0.011 | 0.008 | 743.0 | 1861.0 | 1.01 |
| seasonality_init[10] | -0.479 | 0.292 | -1.038 | 0.053 | 0.010 | 0.007 | 813.0 | 1992.0 | 1.01 |
| seasonality_init[11] | -0.042 | 0.298 | -0.603 | 0.514 | 0.010 | 0.007 | 873.0 | 1890.0 | 1.01 |
| seasonality_init[12] | 0.406 | 0.289 | -0.108 | 0.985 | 0.010 | 0.007 | 791.0 | 1959.0 | 1.01 |
| seasonality_init[13] | 0.670 | 0.291 | 0.152 | 1.245 | 0.010 | 0.007 | 824.0 | 1971.0 | 1.01 |
| seasonality_init[14] | 0.895 | 0.290 | 0.366 | 1.432 | 0.010 | 0.007 | 790.0 | 2190.0 | 1.01 |
| seasonality_smoothing | 0.300 | 0.093 | 0.121 | 0.480 | 0.002 | 0.001 | 2617.0 | 1748.0 | 1.00 |
| trend_init | 0.457 | 0.880 | -1.215 | 2.138 | 0.015 | 0.010 | 3642.0 | 4996.0 | 1.00 |
| trend_smoothing | 0.481 | 0.222 | 0.079 | 0.864 | 0.003 | 0.002 | 5294.0 | 4586.0 | 1.00 |
print(
    f"""Divergences: {damped_holt_winters_idata["sample_stats"]["diverging"].sum().item()}"""
)
Divergences: 0
axes = az.plot_trace(
data=damped_holt_winters_idata,
compact=True,
kind="rank_bars",
backend_kwargs={"figsize": (12, 12), "layout": "constrained"},
)
plt.gcf().suptitle("Damped Holt Winters Trace", fontsize=16)
We do not get any divergences 🙌!
Forecast
Finally, we generate the forecast for the damped Holt-Winters model:
rng_key, rng_subkey = random.split(key=rng_key)
damped_holt_winters_forecast = forecast(
rng_subkey,
damped_holt_winters_model,
damped_holt_winters_mcmc.get_samples(),
y_train,
n_seasons,
y_test.size,
)
damped_holt_winters_posterior_predictive = az.from_numpyro(
posterior_predictive=damped_holt_winters_forecast,
coords={"t": t_test},
dims={"y_forecast": ["t"]},
)
fig, ax = plt.subplots()
az.plot_hdi(
x=t_test,
y=damped_holt_winters_posterior_predictive["posterior_predictive"]["y_forecast"],
hdi_prob=0.94,
color="C2",
fill_kwargs={"alpha": 0.2, "label": r"$94\%$ HDI"},
ax=ax,
)
az.plot_hdi(
x=t_test,
y=damped_holt_winters_posterior_predictive["posterior_predictive"]["y_forecast"],
hdi_prob=0.50,
color="C2",
fill_kwargs={"alpha": 0.5, "label": r"$50\%$ HDI"},
ax=ax,
)
ax.plot(
t_test,
damped_holt_winters_posterior_predictive["posterior_predictive"]["y_forecast"].mean(
dim=("chain", "draw")
),
color="C2",
label="mean forecast",
)
ax.plot(t_train, y_train, color="C0", label="train")
ax.plot(t_test, y_test, color="C1", label="test")
ax.axvline(x=t_train[-1], c="black", linestyle="--")
ax.legend()
ax.set(xlabel="time", ylabel="y", title="Damped Holt Winters Model Forecast")
We indeed see a milder trend forecast as expected (note the posterior distribution of \(\phi\) is far from \(1\))!
This was a great learning exercise for me. It gave me the intuition and practice needed to tackle more complex models like the one presented in the documentation: “Time Series Forecasting: Seasonal, Global Trend (SGT)”.
I hope you also found it useful!