20 min read

Notes on Exponential Smoothing with NumPyro

This notebook serves as personal notes on NumPyro’s implementation of the classic exponential smoothing forecasting method. The strategy is to go into the nitty-gritty details of the code presented in the documentation’s “Example: Holt-Winters Exponential Smoothing”. In particular, I want to understand the auto-regressive components built with the scan function, which always confuses me 😅. After reproducing the example from the documentation, we go a step further and extend the algorithm to include a damped trend.

These notes do not aim to give a complete introduction to exponential smoothing. Instead, we focus on the implementation using NumPyro. For a detailed and comprehensive introduction to the subject (and forecasting topics in general), please refer to the great online book “Forecasting: Principles and Practice” by Rob J Hyndman and George Athanasopoulos. In particular, see Chapter 8 for an introduction to exponential smoothing.

Prepare Notebook

from collections.abc import Callable

import arviz as az
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
import preliz as pz
from jax import random
from jax.lax import fori_loop
from jaxlib.xla_extension import ArrayImpl
from numpyro.contrib.control_flow import scan
from numpyro.infer import MCMC, NUTS, Predictive
from pydantic import BaseModel, Field

az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

numpyro.set_host_device_count(n=4)

rng_key = random.PRNGKey(seed=42)

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"

The Scan Function

We start these notes by revisiting the scan function, as it is a key ingredient in the implementation of the exponential smoothing algorithm. The scan function is a generalization of the for loop: it performs a computation where each step depends on the previous step’s output. This operator uses jax.lax.scan under the hood to allow NumPyro primitives like sample and deterministic to be used in a loop. It is what implements the auto-regressive components of the exponential smoothing algorithm. To gain some intuition about this function, the jax.lax.scan documentation provides a rough pure Python implementation:

def scan(f, init, xs, length=None):
  """Pure Python implementation of scan.

  Parameters
  ----------
  f : a Python function to be scanned.
  init : an initial loop carry value.
  xs : the value over which to scan along the leading axis.
  length : optional integer specifying the number of loop iterations,
      which must agree with the sizes of leading axes of the arrays in
      xs (but can be used to perform scans where no input xs are needed).
  """
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)

Whenever I read this, I get a bit confused. So let’s try to understand the scan function by implementing a simple example.

Example: Sum of Powers

We will implement a simple sum of powers. We want to calculate the sum of the first h powers of a number phi.

\[ \phi \longmapsto \phi^1 + \phi^2 + \ldots + \phi^h \]

We can do this using a for loop. However, we want to use the scan function to get a better understanding of how it works.

First, we implement the sum of powers using a for loop.

def sum_of_powers_for_loop(phi: float, h: int) -> float:
    return sum(phi**i for i in range(1, h + 1))


assert sum_of_powers_for_loop(2, 0) == 0
assert sum_of_powers_for_loop(2, 1) == 2
assert sum_of_powers_for_loop(2, 2) == 2 + 2**2
assert sum_of_powers_for_loop(2, 3) == 2 + 2**2 + 2**3

Now, let’s look at the implementation using the scan function.

def sum_of_powers_scan(phi: float, h: int) -> float:
    def transition_fn(carry, phi):
        # The carry holds the running sum and the latest power of phi.
        power_sum, power = carry
        power = power * phi
        power_sum = power_sum + power
        # Return the updated carry and the value collected at this step.
        return (power_sum, power), power_sum

    # Initial carry: sum 0 and power phi^0 = 1; we scan over h copies of phi.
    (power_sum, _), _ = scan(f=transition_fn, init=(0, 1), xs=jnp.ones(h) * phi)
    return power_sum

We can verify that the two implementations give the same result.

assert sum_of_powers_scan(2, 0) == sum_of_powers_for_loop(2, 0)
assert sum_of_powers_scan(2, 1) == sum_of_powers_for_loop(2, 1)
assert sum_of_powers_scan(2, 2) == sum_of_powers_for_loop(2, 2)
assert sum_of_powers_scan(2, 3) == sum_of_powers_for_loop(2, 3)
assert sum_of_powers_scan(2, 10) == sum_of_powers_for_loop(2, 10)

There is another handy function to write efficient for loops in JAX: the fori_loop function. From the documentation we also get a pure Python implementation:

def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val

In this case we can easily implement the sum of powers using the fori_loop function.

def sum_of_powers_fori_loop(phi: float, h: int) -> float:
    def body_fn(i, power_sum):
        # Add the i-th power of phi to the running sum.
        return power_sum + phi**i

    return fori_loop(lower=1, upper=h + 1, body_fun=body_fn, init_val=0)


assert sum_of_powers_fori_loop(2, 0) == sum_of_powers_for_loop(2, 0)
assert sum_of_powers_fori_loop(2, 1) == sum_of_powers_for_loop(2, 1)
assert sum_of_powers_fori_loop(2, 2) == sum_of_powers_for_loop(2, 2)
assert sum_of_powers_fori_loop(2, 3) == sum_of_powers_for_loop(2, 3)
assert sum_of_powers_fori_loop(2, 10) == sum_of_powers_for_loop(2, 10)

After this brief introduction to the scan function, we are ready to dive into the implementation of the exponential smoothing algorithm using NumPyro.


Generate Synthetic Data

We start by generating some synthetic data, using a data generation process very similar to the one in the example from the documentation: a time series with a trend and a seasonal component.

n_seasons = 15
t = jnp.linspace(start=0, stop=n_seasons + 1, num=(n_seasons + 1) * n_seasons)
y = jnp.cos(2 * jnp.pi * t) + jnp.log(t + 1) + 0.2 * random.normal(rng_key, t.shape)

fig, ax = plt.subplots()
ax.plot(t, y)
ax.set(xlabel="time", ylabel="y", title="Time Series Data")

Train - Test Split

We now split the data into a training and a test set. We will use the training set to fit the model and the test set to evaluate the model’s performance.

n = y.size

prop_train = 0.8
n_train = round(prop_train * n)

y_train = y[:n_train]
t_train = t[:n_train]

y_test = y[n_train:]
t_test = t[n_train:]

fig, ax = plt.subplots()
ax.plot(t_train, y_train, color="C0", label="train")
ax.plot(t_test, y_test, color="C1", label="test")
ax.axvline(x=t_train[-1], c="black", linestyle="--")
ax.legend()
ax.set(xlabel="time", ylabel="y", title="Time Series Data Split")

Level Model

We do not look at the complete model from the start. Instead, we progressively study and add components. We start with the most fundamental component of the model: the level model. The level model is a simple model that predicts the next value in the time series from the current level, a smoothed summary of the past observations. The model is defined by the following equations:

\[\begin{align*} \hat{y}_{t+h|t} = & \: l_t \\ l_t = & \: \alpha y_t + (1 - \alpha)l_{t-1} \end{align*}\]

Here:

  • \(y_t\) is the observed value at time \(t\).
  • \(\hat{y}_{t+h|t}\) is the forecast of the value at time \(t+h\) given the information up to time \(t\).
  • \(l_t\) is the level at time \(t\).
  • \(\alpha\) is the smoothing parameter. It is a value between 0 and 1. A value of 1 means that the forecast is based only on the observed value at time \(t\). A value of 0 means that the forecast is based only on the previous level.

Note that the level equation is a simple weighted average of the observed value and the previous level (for more details see 8.1 Simple exponential smoothing). Moreover, note that this equation is perfectly suited for the scan function.
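To make this connection concrete, here is a minimal deterministic sketch of the level recursion using scan, before we add any priors or noise (alpha and level_init are hypothetical inputs; this is not part of the documentation example). The carry holds the current level, and we collect the one-step-ahead predictions:

def level_recursion(y: ArrayImpl, alpha: float, level_init: float) -> ArrayImpl:
    def transition_fn(carry, y_t):
        previous_level = carry
        # The one-step-ahead forecast is simply the previous level.
        pred = previous_level
        # Level update: weighted average of observation and previous level.
        level = alpha * y_t + (1 - alpha) * previous_level
        return level, pred

    _, preds = scan(transition_fn, level_init, y)
    return preds

For example, level_recursion(y_train, 0.5, 0.0) returns the in-sample one-step-ahead predictions for the training data.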

Model Specification

We start by specifying the level model in NumPyro. We use the same priors as the example in the documentation. Note that we place a uniform prior, Beta(1, 1), on the smoothing parameter \(\alpha\):

fig, ax = plt.subplots()
pz.Beta(alpha=1, beta=1).plot_pdf(ax=ax)
ax.set(title="Beta(1, 1) Prior")

def level_model(y: ArrayImpl, future: int = 0) -> None:
    # Get time series length
    t_max = y.shape[0]

    # --- Priors ---

    ## Level
    level_smoothing = numpyro.sample(
        "level_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    level_init = numpyro.sample("level_init", dist.Normal(loc=0, scale=1))

    ## Noise
    noise = numpyro.sample("noise", dist.HalfNormal(scale=1))

    # --- Transition Function ---

    def transition_fn(carry, t):
        previous_level = carry

        level = jnp.where(
            t < t_max,
            level_smoothing * y[t] + (1 - level_smoothing) * previous_level,
            previous_level,
        )

        mu = previous_level
        pred = numpyro.sample("pred", dist.Normal(loc=mu, scale=noise))

        return level, pred

    # --- Run Scan ---

    with numpyro.handlers.condition(data={"pred": y}):
        _, preds = scan(
            transition_fn,
            level_init,
            jnp.arange(t_max + future),
        )

    # --- Forecast ---
    if future > 0:
        numpyro.deterministic("y_forecast", preds[-future:])

Observe that we are using the condition effect handler to condition the model on the observed data. As explained in the documentation example “Time Series Forecasting”:

The reason is we also want to use this model for forecasting. In forecasting, future values of y are non-observable, so obs=y[t] does not make sense when t >= len(y) (caution: index out-of-bound errors do not get raised in JAX, e.g. jnp.arange(3)[10] == 2). Using condition, when the length of scan is larger than the length of the conditioned/observed site, unobserved values will be sampled from the distribution of that site.
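To see this behavior in isolation, here is a minimal sketch (a hypothetical toy model, not from the documentation): we condition the scanned site on three observations but scan for five steps, so the last two values get sampled from the site’s distribution:

def toy_model(y_obs: ArrayImpl, length: int) -> None:
    def transition_fn(carry, t):
        pred = numpyro.sample("pred", dist.Normal(loc=carry, scale=1))
        return carry, pred

    with numpyro.handlers.condition(data={"pred": y_obs}):
        scan(transition_fn, 0.0, jnp.arange(length))


seeded_model = numpyro.handlers.seed(toy_model, rng_seed=0)
exec_trace = numpyro.handlers.trace(seeded_model).get_trace(
    jnp.array([1.0, 2.0, 3.0]), 5
)
# The first three entries match the conditioned data; the last two are sampled.
print(exec_trace["pred"]["value"])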

Inference

We now proceed to use MCMC to fit the level model to the training data. We define some helper functions, as we will apply them to several models.

class InferenceParams(BaseModel):
    num_warmup: int = Field(2_000, ge=1)
    num_samples: int = Field(2_000, ge=1)
    num_chains: int = Field(4, ge=1)


def run_inference(
    rng_key: ArrayImpl,
    model: Callable,
    args: InferenceParams,
    *model_args,
    **nuts_kwargs,
) -> MCMC:
    sampler = NUTS(model, **nuts_kwargs)
    mcmc = MCMC(
        sampler=sampler,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
    )
    mcmc.run(rng_key, *model_args)
    return mcmc

Let’s fit the model:

inference_params = InferenceParams()
rng_key, rng_subkey = random.split(key=rng_key)
level_mcmc = run_inference(rng_subkey, level_model, inference_params, y_train)

The diagnostics look good:

level_idata = az.from_numpyro(posterior=level_mcmc)

az.summary(data=level_idata)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
level_init 1.013 0.396 0.275 1.778 0.005 0.004 5411.0 4042.0 1.0
level_smoothing 0.974 0.023 0.931 1.000 0.000 0.000 4779.0 3592.0 1.0
noise 0.430 0.022 0.392 0.472 0.000 0.000 6447.0 5120.0 1.0
print(f"""Divergences: {level_idata["sample_stats"]["diverging"].sum().item()}""")
Divergences: 0
axes = az.plot_trace(
    data=level_idata,
    compact=True,
    kind="rank_bars",
    backend_kwargs={"figsize": (12, 7), "layout": "constrained"},
)
plt.gcf().suptitle("Level Model Trace", fontsize=16)

Forecast

Now we can use the fitted model to forecast the test set. As above, we define a helper function for this purpose:

def forecast(
    rng_key: ArrayImpl, model: Callable, samples: dict[str, ArrayImpl], *model_args
) -> dict[str, ArrayImpl]:
    predictive = Predictive(
        model=model,
        posterior_samples=samples,
        return_sites=["y_forecast"],
    )
    return predictive(rng_key, *model_args)

rng_key, rng_subkey = random.split(key=rng_key)
level_forecast = forecast(
    rng_subkey, level_model, level_mcmc.get_samples(), y_train, y_test.size
)
level_posterior_predictive = az.from_numpyro(
    posterior_predictive=level_forecast,
    coords={"t": t_test},
    dims={"y_forecast": ["t"]},
)

We can now visualize the forecast:

fig, ax = plt.subplots()
az.plot_hdi(
    x=t_test,
    y=level_posterior_predictive["posterior_predictive"]["y_forecast"],
    hdi_prob=0.94,
    color="C2",
    fill_kwargs={"alpha": 0.2, "label": r"$94\%$ HDI"},
    ax=ax,
)
az.plot_hdi(
    x=t_test,
    y=level_posterior_predictive["posterior_predictive"]["y_forecast"],
    hdi_prob=0.50,
    color="C2",
    fill_kwargs={"alpha": 0.5, "label": r"$50\%$ HDI"},
    ax=ax,
)
ax.plot(
    t_test,
    level_posterior_predictive["posterior_predictive"]["y_forecast"].mean(
        dim=("chain", "draw")
    ),
    color="C2",
    label="mean forecast",
)
ax.plot(t_train, y_train, color="C0", label="train")
ax.plot(t_test, y_test, color="C1", label="test")
ax.axvline(x=t_train[-1], c="black", linestyle="--")
ax.legend()
ax.set(xlabel="time", ylabel="y", title="Level Model Forecast")

As expected, the forecast is flat: the \(h\)-step-ahead forecast of the level model is just the last level \(l_t\), regardless of \(h\).


Level + Trend Model

Next, we add a trend component to the model. The trend model is defined by the following equations:

\[\begin{align*} \hat{y}_{t+h|t} = & \: l_t + hb_t \\ l_t = & \: \alpha y_t + (1 - \alpha)(l_{t-1} + b_{t-1}) \\ b_t = & \: \beta^*(l_t - l_{t - 1}) + (1-\beta^*)b_{t - 1} \end{align*}\]

Here \(b_t\) denotes the trend at time \(t\) and \(\beta^*\) is a smoothing parameter. The trend equation is a weighted average of the latest level change \(l_t - l_{t-1}\) and the previous trend.

Remark: The superscript \(*\) in \(\beta^*\) is just an artifact of the notation used in the state space representation of the model. See here.

Model Specification

Note that given the level model above, we can easily extend it to include the trend component:

def level_trend_model(y: ArrayImpl, future: int = 0) -> None:
    # Get time series length
    t_max = y.shape[0]

    # --- Priors ---

    ## Level
    level_smoothing = numpyro.sample(
        "level_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    level_init = numpyro.sample("level_init", dist.Normal(loc=0, scale=1))

    ## Trend
    trend_smoothing = numpyro.sample(
        "trend_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    trend_init = numpyro.sample("trend_init", dist.Normal(loc=0, scale=1))

    ## Noise
    noise = numpyro.sample("noise", dist.HalfNormal(scale=1))

    # --- Transition Function ---

    def transition_fn(carry, t):
        previous_level, previous_trend = carry

        level = jnp.where(
            t < t_max,
            level_smoothing * y[t]
            + (1 - level_smoothing) * (previous_level + previous_trend),
            previous_level,
        )

        trend = jnp.where(
            t < t_max,
            trend_smoothing * (level - previous_level)
            + (1 - trend_smoothing) * previous_trend,
            previous_trend,
        )

        step = jnp.where(t < t_max, 1, t - t_max + 1)

        mu = previous_level + step * previous_trend
        pred = numpyro.sample("pred", dist.Normal(loc=mu, scale=noise))

        return (level, trend), pred

    # --- Run Scan ---

    with numpyro.handlers.condition(data={"pred": y}):
        _, preds = scan(
            transition_fn,
            (level_init, trend_init),
            jnp.arange(t_max + future),
        )

    # --- Forecast ---
    if future > 0:
        numpyro.deterministic("y_forecast", preds[-future:])

Inference

We fit the model as before:

rng_key, rng_subkey = random.split(key=rng_key)
level_trend_mcmc = run_inference(
    rng_subkey, level_trend_model, inference_params, y_train
)
level_trend_idata = az.from_numpyro(posterior=level_trend_mcmc)

az.summary(data=level_trend_idata)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
level_init 1.224 0.378 0.518 1.928 0.005 0.004 4965.0 4681.0 1.0
level_smoothing 0.641 0.047 0.560 0.733 0.001 0.001 4481.0 3110.0 1.0
noise 0.430 0.023 0.388 0.473 0.000 0.000 5229.0 5336.0 1.0
trend_init 0.011 0.344 -0.619 0.664 0.005 0.004 4819.0 4917.0 1.0
trend_smoothing 0.912 0.077 0.771 1.000 0.001 0.001 3466.0 3557.0 1.0
print(f"""Divergences: {level_trend_idata["sample_stats"]["diverging"].sum().item()}""")
Divergences: 0
axes = az.plot_trace(
    data=level_trend_idata,
    compact=True,
    kind="rank_bars",
    backend_kwargs={"figsize": (12, 9), "layout": "constrained"},
)
plt.gcf().suptitle("Level + Trend Model Trace", fontsize=16)

Forecast

We generate the forecast:

rng_key, rng_subkey = random.split(key=rng_key)
level_trend_forecast = forecast(
    rng_subkey, level_trend_model, level_trend_mcmc.get_samples(), y_train, y_test.size
)

level_trend_posterior_predictive = az.from_numpyro(
    posterior_predictive=level_trend_forecast,
    coords={"t": t_test},
    dims={"y_forecast": ["t"]},
)
fig, ax = plt.subplots()
az.plot_hdi(
    x=t_test,
    y=level_trend_posterior_predictive["posterior_predictive"]["y_forecast"],
    hdi_prob=0.94,
    color="C2",
    fill_kwargs={"alpha": 0.2, "label": r"$94\%$ HDI"},
    ax=ax,
)
az.plot_hdi(
    x=t_test,
    y=level_trend_posterior_predictive["posterior_predictive"]["y_forecast"],
    hdi_prob=0.50,
    color="C2",
    fill_kwargs={"alpha": 0.5, "label": r"$50\%$ HDI"},
    ax=ax,
)
ax.plot(
    t_test,
    level_trend_posterior_predictive["posterior_predictive"]["y_forecast"].mean(
        dim=("chain", "draw")
    ),
    color="C2",
    label="mean forecast",
)
ax.plot(t_train, y_train, color="C0", label="train")
ax.plot(t_test, y_test, color="C1", label="test")
ax.axvline(x=t_train[-1], c="black", linestyle="--")
ax.legend()
ax.set(xlabel="time", ylabel="y", title="Level + Trend Model Forecast")

Oops! This is literally just extrapolating with a straight line 😒. Still, this is often quite a good forecasting baseline for short-term horizons. We are clearly missing the seasonality component to complement the trend. This is what we add next!


Level + Trend + Seasonality Model (Holt-Winters)

We have finally arrived at the (additive) level + trend + seasonality model, also known as the Holt-Winters model. We extend the trend model to include a seasonal component:

\[\begin{align*} \hat{y}_{t+h|t} = & \: l_t + hb_t + s_{t + h - m(k + 1)} \\ l_t = & \: \alpha(y_t - s_{t - m}) + (1 - \alpha)(l_{t - 1} + b_{t - 1}) \\ b_t = & \: \beta^*(l_t - l_{t - 1}) + (1 - \beta^*)b_{t - 1} \\ s_t = & \: \gamma(y_t - l_{t - 1} - b_{t - 1})+(1 - \gamma)s_{t - m} \end{align*}\]

We have added the seasonal component \(s_t\) with the corresponding smoothing parameter \(\gamma\). Here \(m\) denotes the number of seasons. The parameter \(k\) is the integer part of \((h - 1)/m\), which ensures that the seasonal index always points to the latest available seasonal estimate (for instance, with \(m = 15\) and \(h = 20\) we get \(k = 1\), so the seasonal index is \(t + 20 - 30 = t - 10\)). In particular, if \((h - 1)/m\) is an integer, then \(k = (h - 1)/m\) and

\[ t + h - m(k + 1) = t + h - (h - 1) - m = t + 1 - m \]

Remark: Similar to the note on the \(\beta^*\) notation, the parameter \(\gamma\) is often called an adjusted smoothing parameter, as it is of the form \(\gamma = \gamma^{*}(1 - \alpha)\), so that \(0 \leq \gamma^{*} \leq 1\) translates to \(0 \leq \gamma \leq 1 - \alpha\). This is again a result of the state space representation of the model. See here.

Model Specification

We now use the model from the example:

def holt_winters_model(y: ArrayImpl, n_seasons: int, future: int = 0) -> None:
    # Get time series length
    t_max = y.shape[0]

    # --- Priors ---

    ## Level
    level_smoothing = numpyro.sample(
        "level_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    level_init = numpyro.sample("level_init", dist.Normal(loc=0, scale=1))

    ## Trend
    trend_smoothing = numpyro.sample(
        "trend_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    trend_init = numpyro.sample("trend_init", dist.Normal(loc=0, scale=1))

    ## Seasonality
    seasonality_smoothing = numpyro.sample(
        "seasonality_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    adj_seasonality_smoothing = seasonality_smoothing * (1 - level_smoothing)

    with numpyro.plate("n_seasons", n_seasons):
        seasonality_init = numpyro.sample(
            "seasonality_init", dist.Normal(loc=0, scale=1)
        )

    ## Noise
    noise = numpyro.sample("noise", dist.HalfNormal(scale=1))

    # --- Transition Function ---

    def transition_fn(carry, t):
        previous_level, previous_trend, previous_seasonality = carry

        level = jnp.where(
            t < t_max,
            level_smoothing * (y[t] - previous_seasonality[0])
            + (1 - level_smoothing) * (previous_level + previous_trend),
            previous_level,
        )

        trend = jnp.where(
            t < t_max,
            trend_smoothing * (level - previous_level)
            + (1 - trend_smoothing) * previous_trend,
            previous_trend,
        )

        new_season = jnp.where(
            t < t_max,
            adj_seasonality_smoothing * (y[t] - (previous_level + previous_trend))
            + (1 - adj_seasonality_smoothing) * previous_seasonality[0],
            previous_seasonality[0],
        )

        step = jnp.where(t < t_max, 1, t - t_max + 1)

        mu = previous_level + step * previous_trend + previous_seasonality[0]
        pred = numpyro.sample("pred", dist.Normal(loc=mu, scale=noise))

        seasonality = jnp.concatenate(
            [previous_seasonality[1:], new_season[None]], axis=0
        )

        return (level, trend, seasonality), pred

    # --- Run Scan ---

    with numpyro.handlers.condition(data={"pred": y}):
        _, preds = scan(
            transition_fn,
            (level_init, trend_init, seasonality_init),
            jnp.arange(t_max + future),
        )

    # --- Forecast ---
    if future > 0:
        numpyro.deterministic("y_forecast", preds[-future:])

Remark [Seasonality Looping]: Please note that the new_season variable in the transition_fn simply computes the next seasonal step as in the mathematical formulation of the model. Observe how the previous season \(s_{t - m}\) is given by previous_seasonality[0]. The index zero may look odd, but it makes sense since we roll the seasonality buffer (part of the carry) below as

seasonality = jnp.concatenate(
    [previous_seasonality[1:], new_season[None]], axis=0
)

so that the first item of this array always corresponds to the previous seasonality \(s_{t - m}\) in the update formula 💡.
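A tiny standalone example (with made-up numbers and m = 3) shows this rolling buffer in action:

previous_seasonality = jnp.array([0.1, 0.2, 0.3])  # index 0 holds s_{t - m}
new_season = jnp.array(0.4)

seasonality = jnp.concatenate([previous_seasonality[1:], new_season[None]], axis=0)
print(seasonality)  # [0.2 0.3 0.4] -> 0.2 is s_{t - m} for the next iteration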

Inference

Note that all of these models sample super fast!

rng_key, rng_subkey = random.split(key=rng_key)
holt_winters_mcmc = run_inference(
    rng_subkey,
    holt_winters_model,
    inference_params,
    y_train,
    n_seasons,
)
holt_winters_idata = az.from_numpyro(posterior=holt_winters_mcmc)

az.summary(data=holt_winters_idata)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
level_init 0.162 0.278 -0.359 0.681 0.008 0.006 1153.0 1703.0 1.0
level_smoothing 0.117 0.050 0.024 0.209 0.001 0.001 2256.0 2089.0 1.0
noise 0.256 0.014 0.231 0.282 0.000 0.000 3276.0 4369.0 1.0
seasonality_init[0] 0.965 0.296 0.433 1.545 0.008 0.006 1317.0 1721.0 1.0
seasonality_init[1] 0.908 0.300 0.372 1.496 0.008 0.006 1271.0 1967.0 1.0
seasonality_init[2] 0.542 0.297 -0.023 1.091 0.008 0.006 1299.0 1953.0 1.0
seasonality_init[3] 0.258 0.296 -0.282 0.818 0.008 0.006 1310.0 1900.0 1.0
seasonality_init[4] 0.021 0.298 -0.544 0.562 0.008 0.006 1317.0 1961.0 1.0
seasonality_init[5] -0.661 0.297 -1.246 -0.126 0.008 0.006 1299.0 1983.0 1.0
seasonality_init[6] -0.836 0.299 -1.375 -0.258 0.008 0.006 1305.0 1993.0 1.0
seasonality_init[7] -0.989 0.295 -1.554 -0.440 0.008 0.006 1260.0 1995.0 1.0
seasonality_init[8] -0.773 0.293 -1.330 -0.220 0.008 0.006 1324.0 2162.0 1.0
seasonality_init[9] -0.824 0.295 -1.370 -0.266 0.008 0.006 1277.0 1950.0 1.0
seasonality_init[10] -0.542 0.299 -1.112 0.005 0.009 0.006 1208.0 1912.0 1.0
seasonality_init[11] -0.091 0.299 -0.642 0.475 0.008 0.006 1354.0 2243.0 1.0
seasonality_init[12] 0.387 0.296 -0.173 0.939 0.008 0.006 1309.0 2039.0 1.0
seasonality_init[13] 0.663 0.297 0.111 1.225 0.008 0.006 1261.0 1827.0 1.0
seasonality_init[14] 0.910 0.298 0.349 1.459 0.008 0.006 1311.0 1942.0 1.0
seasonality_smoothing 0.299 0.082 0.145 0.454 0.002 0.001 2588.0 1663.0 1.0
trend_init 0.029 0.012 0.007 0.053 0.000 0.000 2908.0 3843.0 1.0
trend_smoothing 0.069 0.063 0.000 0.145 0.002 0.001 2153.0 2119.0 1.0
print(
    f"""Divergences: {holt_winters_idata["sample_stats"]["diverging"].sum().item()}"""
)
Divergences: 0
axes = az.plot_trace(
    data=holt_winters_idata,
    compact=True,
    kind="rank_bars",
    backend_kwargs={"figsize": (12, 12), "layout": "constrained"},
)
plt.gcf().suptitle("Holt Winters Trace", fontsize=16)

The diagnostics and the trace plot look good!

Forecast

Let’s see if this model can forecast the test set better than the previous models:

rng_key, rng_subkey = random.split(key=rng_key)
holt_winters_forecast = forecast(
    rng_subkey,
    holt_winters_model,
    holt_winters_mcmc.get_samples(),
    y_train,
    n_seasons,
    y_test.size,
)

holt_winters_posterior_predictive = az.from_numpyro(
    posterior_predictive=holt_winters_forecast,
    coords={"t": t_test},
    dims={"y_forecast": ["t"]},
)
fig, ax = plt.subplots()
az.plot_hdi(
    x=t_test,
    y=holt_winters_posterior_predictive["posterior_predictive"]["y_forecast"],
    hdi_prob=0.94,
    color="C2",
    fill_kwargs={"alpha": 0.2, "label": r"$94\%$ HDI"},
    ax=ax,
)
az.plot_hdi(
    x=t_test,
    y=holt_winters_posterior_predictive["posterior_predictive"]["y_forecast"],
    hdi_prob=0.50,
    color="C2",
    fill_kwargs={"alpha": 0.5, "label": r"$50\%$ HDI"},
    ax=ax,
)
ax.plot(
    t_test,
    holt_winters_posterior_predictive["posterior_predictive"]["y_forecast"].mean(
        dim=("chain", "draw")
    ),
    color="C2",
    label="mean forecast",
)
ax.plot(t_train, y_train, color="C0", label="train")
ax.plot(t_test, y_test, color="C1", label="test")
ax.axvline(x=t_train[-1], c="black", linestyle="--")
ax.legend()
ax.set(xlabel="time", ylabel="y", title="Holt Winters Model Forecast")

Indeed! The forecast looks pretty good 🚀! The point forecast is slightly over-forecasting, but still within the credible range. Let’s see how we can make the trend forecast a bit more conservative 🤓.


Damped Holt-Winters

As a final model, we present an extension of the model above. We add a damping parameter \(\phi\) to the trend component. The damped Holt-Winters model is defined by the following equations:

\[\begin{align*} \hat{y}_{t+h|t} = & \: l_t + \phi_{h}b_t + s_{t + h - m(k + 1)} \\ l_t = & \: \alpha(y_t - s_{t - m}) + (1 - \alpha)(l_{t - 1} + \phi b_{t - 1}) \\ b_t = & \: \beta^*(l_t - l_{t - 1}) + (1 - \beta^*)\phi b_{t - 1} \\ s_t = & \: \gamma(y_t - l_{t - 1} - \phi b_{t - 1})+(1 - \gamma)s_{t - m} \end{align*}\]

where \(0 \leq \phi \leq 1\) and \(\phi_{h} := \phi + \phi^2 + \ldots + \phi^{h}\).

Remark: Now you see why we were interested in the power sum scan example at the beginning of these notes 😅.
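As a side note (a standard geometric series identity, not something used in the example), for \(\phi \neq 1\) the damping sum has the closed form \(\phi_{h} = \phi(1 - \phi^{h})/(1 - \phi)\). We can use it to sanity-check the power sum helpers from the beginning of these notes:

phi, h = 0.9, 5
closed_form = phi * (1 - phi**h) / (1 - phi)
assert jnp.isclose(sum_of_powers_fori_loop(phi, h), closed_form)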

Model Specification

Substituting \(b_{t -1} \longmapsto \phi b_{t -1}\) is straightforward. What needs a bit more attention is the \(\phi_{h}\) term. A plain Python for loop is a no-go, as the horizon is a traced value. The scan approach does not work either, since the size of the xs argument would be dynamic. Hence, the fori_loop function is the way to go (it supports traced bounds by lowering to a while_loop). Moreover, in order to help the sampler, we impose more informative priors on the smoothing parameters.

fig, ax = plt.subplots()
pz.Beta(alpha=1, beta=1).plot_pdf(ax=ax)
pz.Beta(alpha=2, beta=2).plot_pdf(ax=ax)
pz.Beta(alpha=2, beta=5).plot_pdf(ax=ax)
ax.set(title="Beta Priors")

def damped_holt_winters_model(y: ArrayImpl, n_seasons: int, future: int = 0) -> None:
    # Get time series length
    t_max = y.shape[0]

    # --- Priors ---

    ## Level
    level_smoothing = numpyro.sample("level_smoothing", dist.Beta(2, 2))
    level_init = numpyro.sample("level_init", dist.Normal(0, 1))

    ## Trend
    trend_smoothing = numpyro.sample("trend_smoothing", dist.Beta(2, 2))
    trend_init = numpyro.sample("trend_init", dist.Normal(0.5, 1))

    ## Seasonality
    seasonality_smoothing = numpyro.sample("seasonality_smoothing", dist.Beta(2, 2))
    adj_seasonality_smoothing = seasonality_smoothing * (1 - level_smoothing)

    ## Damping
    phi = numpyro.sample("phi", dist.Beta(2, 5))

    with numpyro.plate("n_seasons", n_seasons):
        seasonality_init = numpyro.sample("seasonality_init", dist.Normal(0, 1))

    ## Noise
    noise = numpyro.sample("noise", dist.HalfNormal(1))

    # --- Transition Function ---

    def transition_fn(carry, t):
        previous_level, previous_trend, previous_seasonality = carry

        level = jnp.where(
            t < t_max,
            level_smoothing * (y[t] - previous_seasonality[0])
            + (1 - level_smoothing) * (previous_level + phi * previous_trend),
            previous_level,
        )

        trend = jnp.where(
            t < t_max,
            trend_smoothing * (level - previous_level)
            + (1 - trend_smoothing) * phi * previous_trend,
            previous_trend,
        )

        new_season = jnp.where(
            t < t_max,
            adj_seasonality_smoothing * (y[t] - (previous_level + phi * previous_trend))
            + (1 - adj_seasonality_smoothing) * previous_seasonality[0],
            previous_seasonality[0],
        )

        step = jnp.where(t < t_max, 1, t - t_max + 1)
        phi_step = fori_loop(
            lower=1, upper=step + 1, body_fun=lambda i, val: val + phi**i, init_val=0
        )

        mu = previous_level + phi_step * previous_trend + previous_seasonality[0]
        pred = numpyro.sample("pred", dist.Normal(mu, noise))

        seasonality = jnp.concatenate(
            [previous_seasonality[1:], new_season[None]], axis=0
        )

        return (level, trend, seasonality), pred

    # --- Run Scan ---

    with numpyro.handlers.condition(data={"pred": y}):
        _, preds = scan(
            transition_fn,
            (level_init, trend_init, seasonality_init),
            jnp.arange(t_max + future),
        )

    # --- Forecast ---
    if future > 0:
        numpyro.deterministic("y_forecast", preds[-future:])

Inference

An important note is that the sampler needs the forward_mode_differentiation parameter because we are using the fori_loop function with a dynamic upper bound. In that case fori_loop lowers to a while_loop, which does not support reverse-mode differentiation. This is a bit of a bummer, as reverse-mode differentiation is often more efficient. However, the sampler still runs pretty fast.
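Here is a minimal sketch of this limitation, independent of our model (power_sum below is just a toy function): forward-mode differentiation of a fori_loop with a traced upper bound works, while reverse mode raises an error.

import jax


def power_sum(phi, h):
    # With a traced upper bound, fori_loop lowers to a while_loop.
    return fori_loop(
        lower=1, upper=h + 1, body_fun=lambda i, val: val + phi**i, init_val=0.0
    )


# Forward mode works (jit makes h a traced value):
print(jax.jit(jax.jacfwd(power_sum))(0.9, jnp.array(5)))
# Reverse mode raises an error pointing at while_loop:
# jax.jit(jax.grad(power_sum))(0.9, jnp.array(5))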

rng_key, rng_subkey = random.split(key=rng_key)
damped_holt_winters_mcmc = run_inference(
    rng_subkey,
    damped_holt_winters_model,
    inference_params,
    y_train,
    n_seasons,
    target_accept_prob=0.95,
    forward_mode_differentiation=True,
)
damped_holt_winters_idata = az.from_numpyro(posterior=damped_holt_winters_mcmc)

az.summary(data=damped_holt_winters_idata)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
level_init 0.141 0.375 -0.575 0.836 0.011 0.008 1237.0 3017.0 1.01
level_smoothing 0.224 0.050 0.134 0.320 0.001 0.001 3590.0 4847.0 1.00
noise 0.262 0.014 0.235 0.288 0.000 0.000 4121.0 4769.0 1.00
phi 0.241 0.132 0.010 0.471 0.002 0.001 3654.0 3599.0 1.00
seasonality_init[0] 0.944 0.289 0.414 1.501 0.010 0.007 803.0 1878.0 1.01
seasonality_init[1] 0.893 0.291 0.385 1.473 0.010 0.007 807.0 2189.0 1.01
seasonality_init[2] 0.543 0.292 0.003 1.087 0.010 0.007 797.0 2162.0 1.01
seasonality_init[3] 0.275 0.289 -0.254 0.829 0.010 0.007 803.0 1939.0 1.01
seasonality_init[4] 0.053 0.296 -0.509 0.595 0.010 0.007 819.0 2140.0 1.01
seasonality_init[5] -0.602 0.286 -1.117 -0.050 0.010 0.007 792.0 2124.0 1.01
seasonality_init[6] -0.771 0.288 -1.294 -0.211 0.010 0.007 819.0 2157.0 1.01
seasonality_init[7] -0.928 0.287 -1.448 -0.365 0.010 0.007 774.0 1975.0 1.01
seasonality_init[8] -0.713 0.289 -1.261 -0.179 0.010 0.007 866.0 2105.0 1.01
seasonality_init[9] -0.759 0.289 -1.302 -0.215 0.011 0.008 743.0 1861.0 1.01
seasonality_init[10] -0.479 0.292 -1.038 0.053 0.010 0.007 813.0 1992.0 1.01
seasonality_init[11] -0.042 0.298 -0.603 0.514 0.010 0.007 873.0 1890.0 1.01
seasonality_init[12] 0.406 0.289 -0.108 0.985 0.010 0.007 791.0 1959.0 1.01
seasonality_init[13] 0.670 0.291 0.152 1.245 0.010 0.007 824.0 1971.0 1.01
seasonality_init[14] 0.895 0.290 0.366 1.432 0.010 0.007 790.0 2190.0 1.01
seasonality_smoothing 0.300 0.093 0.121 0.480 0.002 0.001 2617.0 1748.0 1.00
trend_init 0.457 0.880 -1.215 2.138 0.015 0.010 3642.0 4996.0 1.00
trend_smoothing 0.481 0.222 0.079 0.864 0.003 0.002 5294.0 4586.0 1.00
print(
    f"""Divergences: {
        damped_holt_winters_idata["sample_stats"]["diverging"].sum().item()
    }
"""
)
Divergences: 0
axes = az.plot_trace(
    data=damped_holt_winters_idata,
    compact=True,
    kind="rank_bars",
    backend_kwargs={"figsize": (12, 12), "layout": "constrained"},
)
plt.gcf().suptitle("Damped Holt Winters Trace", fontsize=16)

We do not get any divergences 🙌!

Forecast

Finally, we generate the forecast for the damped Holt-Winters model:

rng_key, rng_subkey = random.split(key=rng_key)
damped_holt_winters_forecast = forecast(
    rng_subkey,
    damped_holt_winters_model,
    damped_holt_winters_mcmc.get_samples(),
    y_train,
    n_seasons,
    y_test.size,
)

damped_holt_winters_posterior_predictive = az.from_numpyro(
    posterior_predictive=damped_holt_winters_forecast,
    coords={"t": t_test},
    dims={"y_forecast": ["t"]},
)
fig, ax = plt.subplots()
az.plot_hdi(
    x=t_test,
    y=damped_holt_winters_posterior_predictive["posterior_predictive"]["y_forecast"],
    hdi_prob=0.94,
    color="C2",
    fill_kwargs={"alpha": 0.2, "label": r"$94\%$ HDI"},
    ax=ax,
)
az.plot_hdi(
    x=t_test,
    y=damped_holt_winters_posterior_predictive["posterior_predictive"]["y_forecast"],
    hdi_prob=0.50,
    color="C2",
    fill_kwargs={"alpha": 0.5, "label": r"$50\%$ HDI"},
    ax=ax,
)
ax.plot(
    t_test,
    damped_holt_winters_posterior_predictive["posterior_predictive"]["y_forecast"].mean(
        dim=("chain", "draw")
    ),
    color="C2",
    label="mean forecast",
)
ax.plot(t_train, y_train, color="C0", label="train")
ax.plot(t_test, y_test, color="C1", label="test")
ax.axvline(x=t_train[-1], c="black", linestyle="--")
ax.legend()
ax.set(xlabel="time", ylabel="y", title="Damped Holt Winters Model Forecast")

We indeed see a milder trend forecast, as expected (note that the posterior distribution of \(\phi\) concentrates far below \(1\))!


This was a great learning exercise for me. It gave me the intuition and practice needed to tackle more complex models like the one presented in the documentation: “Time Series Forecasting: Seasonal, Global Trend (SGT)”.

I hope you also found it useful!