
Bayesian Vector Autoregressive Models in NumPyro

In this notebook, we present how to implement and fit Bayesian Vector Autoregressive (VAR) models using NumPyro. We work out three components:

  1. Specifying and fitting the model in NumPyro
  2. Using the model to generate forecasts
  3. Computing the Impulse Response Functions (IRFs)

We compare each of these components against the results obtained with the statsmodels implementation from the Vector Autoregressions (tsa.vector_ar) tutorial.

Prepare Notebook

from functools import partial

import arviz as az
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
import statsmodels.api as sm
import xarray as xr
from jax import jit, lax, random, vmap
from jaxtyping import Array, Float
from numpyro.contrib.control_flow import scan
from numpyro.handlers import condition
from numpyro.infer import MCMC, NUTS
from statsmodels.tsa.api import VAR
from statsmodels.tsa.base.datetools import dates_from_str

numpyro.set_host_device_count(n=10)

rng_key = random.PRNGKey(seed=42)

az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

%load_ext autoreload
%autoreload 2
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype
%config InlineBackend.figure_format = "retina"

Load Data

We are going to use a dataset from the statsmodels package. Specifically, we will use the macrodata dataset from the Vector Autoregressions (tsa.vector_ar) tutorial. For the sake of reproducibility, we keep the exact same code as in the tutorial.

def load_data() -> pd.DataFrame:
    mdata = sm.datasets.macrodata.load_pandas().data
    dates = mdata[["year", "quarter"]].astype(int).astype(str)
    quarterly = dates["year"] + "Q" + dates["quarter"]
    quarterly = dates_from_str(quarterly)
    mdata = mdata[["realgdp", "realcons", "realinv"]]
    mdata.index = pd.DatetimeIndex(quarterly, freq="QE")
    return np.log(mdata).diff().dropna()


data: pd.DataFrame = load_data()

We start by visualizing the data.

fig, ax = plt.subplots()
data.plot(ax=ax)
ax.set(xlabel="Time", ylabel="Log Differenced Data")
ax.set_title("Log Differenced Data", fontsize=18, fontweight="bold");

It looks like the data is stationary.

Fit VAR Model with Statsmodels

Recall that a \(\text{VAR}(p)\) model can be written as:

\[ Y_t = c + \sum_{j=1}^{p} \Phi_j Y_{t-j} + \varepsilon_t, \]

where \(c\) is a vector of constants, \(\Phi_j\) is the coefficient matrix for the \(j\)-th lag, and \(\varepsilon_t\) is the error term, a vector of i.i.d. normal random variables with mean \(0\) and covariance matrix \(\Sigma\). Each matrix \(\Phi_j\) has dimensions \((k, k)\), where \(k\) is the number of variables in the model. Let \(\Phi = [\Phi_1, \Phi_2, \ldots, \Phi_p]\) be the tensor of coefficient matrices of shape \((p, k, k)\), so that we can write the model in vectorized form as \[ Y_t = c + \Phi \times \begin{bmatrix} Y_{t-1} \\ Y_{t-2} \\ \vdots \\ Y_{t-p} \end{bmatrix} + \varepsilon_t. \] Here \(\times\) denotes the contraction over the lag and variable dimensions of the stacked lags, which is exactly the einsum operation we implement below.
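To make the notation concrete, here is a minimal simulation sketch of a \(\text{VAR}(2)\) process. The coefficients below are made up for illustration only; they are not estimated from the data and are not part of the original notebook:

# Made-up VAR(2) example, purely to illustrate the recursion above.
k, p, t_max = 3, 2, 250
rng = np.random.default_rng(seed=0)

c_sim = np.array([0.01, 0.005, -0.02])
# Small random coefficients so that the simulated process is (very likely) stable.
phi_sim = 0.2 * rng.standard_normal(size=(p, k, k))  # shape (lags, vars, vars)
chol_sim = 0.01 * np.eye(k)  # Cholesky factor of the error covariance

y_sim = np.zeros((t_max, k))
for t in range(p, t_max):
    # Lagged values ordered newest first, so phi_sim[0] multiplies y_{t-1}.
    y_lags_sim = y_sim[t - p : t][::-1]
    y_sim[t] = (
        c_sim
        + sum(phi_sim[j] @ y_lags_sim[j] for j in range(p))
        + chol_sim @ rng.standard_normal(size=k)
    )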

Before implementing the model in NumPyro, we fit a VAR model using the statsmodels package to get reference values.

var_model = VAR(data)

var_results = var_model.fit(maxlags=2)

var_results.summary()
  Summary of Regression Results   
==================================
Model:                         VAR
Method:                        OLS
Date:           Fri, 03, Oct, 2025
Time:                     14:59:20
--------------------------------------------------------------------
No. of Equations:         3.00000    BIC:                   -27.5830
Nobs:                     200.000    HQIC:                  -27.7892
Log likelihood:           1962.57    FPE:                7.42129e-13
AIC:                     -27.9293    Det(Omega_mle):     6.69358e-13
--------------------------------------------------------------------
Results for equation realgdp
==============================================================================
                 coefficient       std. error           t-stat            prob
------------------------------------------------------------------------------
const               0.001527         0.001119            1.365           0.172
L1.realgdp         -0.279435         0.169663           -1.647           0.100
L1.realcons         0.675016         0.131285            5.142           0.000
L1.realinv          0.033219         0.026194            1.268           0.205
L2.realgdp          0.008221         0.173522            0.047           0.962
L2.realcons         0.290458         0.145904            1.991           0.047
L2.realinv         -0.007321         0.025786           -0.284           0.776
==============================================================================

Results for equation realcons
==============================================================================
                 coefficient       std. error           t-stat            prob
------------------------------------------------------------------------------
const               0.005460         0.000969            5.634           0.000
L1.realgdp         -0.100468         0.146924           -0.684           0.494
L1.realcons         0.268640         0.113690            2.363           0.018
L1.realinv          0.025739         0.022683            1.135           0.257
L2.realgdp         -0.123174         0.150267           -0.820           0.412
L2.realcons         0.232499         0.126350            1.840           0.066
L2.realinv          0.023504         0.022330            1.053           0.293
==============================================================================

Results for equation realinv
==============================================================================
                 coefficient       std. error           t-stat            prob
------------------------------------------------------------------------------
const              -0.023903         0.005863           -4.077           0.000
L1.realgdp         -1.970974         0.888892           -2.217           0.027
L1.realcons         4.414162         0.687825            6.418           0.000
L1.realinv          0.225479         0.137234            1.643           0.100
L2.realgdp          0.380786         0.909114            0.419           0.675
L2.realcons         0.800281         0.764416            1.047           0.295
L2.realinv         -0.124079         0.135098           -0.918           0.358
==============================================================================

Correlation matrix of residuals
             realgdp  realcons   realinv
realgdp     1.000000  0.603316  0.750722
realcons    0.603316  1.000000  0.131951
realinv     0.750722  0.131951  1.000000

NumPyro Implementation

Next, we implement the model in NumPyro. The core idea is taken from the NumPyro docs: Example: VAR(2) process. In our implementation, we vectorize the computation over the lag components, so the model works for an arbitrary number of lags (not just \(p = 2\)).

Vectorization Over Lags

The vectorization is a bit tricky at first, so before jumping into the model we consider a simple example. The idea is to vectorize the coefficient matrix \(\Phi\) over the lags \(j=1, \ldots, p\). Let us take \(p=2\) and generate a synthetic matrix \(\Phi\); for now we only care about the computation, not the values themselves.

# number of variables (taken from the data)
n_vars = data.shape[1]
# number of lags
n_lags = 2

# generate a synthetic matrix
phi = jnp.arange(n_lags * n_vars * n_vars).reshape(n_lags, n_vars, n_vars)

phi
Array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]]], dtype=int32)

For the matrix \(\Phi\), the first dimension indexes the lags, the second indexes the equations (rows), and the third indexes the variables (columns) that we multiply by the lagged values and sum over.

Next, we consider a synthetic array of lagged values, representing the lags of the dependent variable \(Y_t\).

y_lags = 2 * jnp.ones((n_lags, n_vars))

y_lags
Array([[2., 2., 2.],
       [2., 2., 2.]], dtype=float32)

The tensor operation we want to perform is:

# Broadcast y_lags over the row dimension, so phi[l, i, j] is multiplied by
# y_lags[l, j], then sum over the lag and column dimensions (axes 0 and 2).
(phi * y_lags[:, jnp.newaxis, :]).sum(axis=(0, 2))
Array([ 66., 102., 138.], dtype=float32)

This can be achieved by using the einsum function with a proper specification of the indices.

jnp.einsum("lij,lj->i", phi, y_lags)
Array([ 66., 102., 138.], dtype=float32)
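As a sanity check on the index convention (a small addition, not in the original notebook), the same result is obtained by summing the per-lag matrix-vector products \(\Phi_j Y_{t-j}\) in a plain Python loop:

# Explicit loop over lags: sum_j Phi_j @ y_lags[j]
sum(phi[lag] @ y_lags[lag] for lag in range(n_lags))
Array([ 66., 102., 138.], dtype=float32)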

We will use the einsum function to perform the operation in the NumPyro model.

NumPyro Model

We are now ready to implement the VAR model in NumPyro. The idea is to use the scan function, as in the previous example Notes on an ARMA(1, 1) Model with NumPyro; see also the PyData Amsterdam video Time Series forecasting with NumPyro for more details.

def model(y: Float[Array, "time vars"], n_lags: int, future: int = 0) -> None:
    # Get the number of time steps and variables
    n_time, n_vars = y.shape

    # --- Priors ---
    constant = numpyro.sample("constant", dist.Normal(loc=0, scale=1).expand([n_vars]))

    sigma = numpyro.sample("sigma", dist.HalfNormal(scale=1.0).expand([n_vars]))

    l_omega = numpyro.sample(
        "l_omega", dist.LKJCholesky(dimension=n_vars, concentration=1.0)
    )
    l_sigma = jnp.einsum("...i,...ij->...ij", sigma, l_omega)

    # Sample phi coefficients - shape (n_lags, n_vars, n_vars)
    # The first dimension is the lags, the second dimension is the variables on the rows
    # and the third dimension is the variables on the columns that
    # we want to multiply and sum.
    phi = numpyro.sample(
        "phi", dist.Normal(loc=0, scale=10).expand([n_lags, n_vars, n_vars]).to_event(3)
    )

    # --- Transition Function ---
    def transition_fn(carry: Array, _, name: str) -> tuple[Array, Array]:
        # carry: (n_lags, n_vars)
        y_lags = carry

        # Compute lag contributions as a matrix product of phi and y_lags
        # (see the example above!)
        # Here the only trick is to reverse the phi lag coordinates. Why?
        # The first entry in the initial `carry` array `init_carry = y[:n_lags]`
        # is the oldest lag and the last entry is the newest lag.
        lag_contributions = jnp.einsum("lij,lj->i", phi[::-1], y_lags)
        # Compute VAR mean
        m_t = constant + lag_contributions

        # Sample observation
        y_t = numpyro.sample(name, dist.MultivariateNormal(loc=m_t, scale_tril=l_sigma))

        # Update carry: remove oldest, add newest
        new_carry = jnp.concatenate([y_lags[1:], y_t[None, :]], axis=0)
        return new_carry, y_t

    inference_fn = partial(transition_fn, name="y_pred")

    # Initialize and run scan
    init_carry = y[:n_lags]
    time_indices = jnp.arange(n_lags, n_time)

    with condition(data={"y_pred": y[n_lags:]}):
        scan(inference_fn, init=init_carry, xs=time_indices)
    
    # Generate forecast
    if future > 0:
        prediction_fn = partial(transition_fn, name="y_future")
        scan(prediction_fn, init=y[-n_lags:], xs=jnp.arange(future))

Let’s visualize the model:

y: Float[Array, "time vars"] = jnp.array(data)

numpyro.render_model(
    model,
    model_kwargs={"y": y, "n_lags": 2, "future": 10},
    render_distributions=True,
    render_params=True,
)

Here we see the model structure and the two outputs: y_pred (inference) and y_future (prediction).

Fit NumPyro Model

We now sample from the posterior distribution of the model using MCMC.

%%time

nuts_kernel = NUTS(model)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=1_000,
    num_samples=1_000,
    num_chains=4,
)

# Run inference
n_lags = 2
future = 30
rng_key, rng_subkey = random.split(rng_key)
mcmc.run(rng_subkey, y=y, n_lags=n_lags, future=future)

# Get samples
samples = mcmc.get_samples()
CPU times: user 2min 22s, sys: 8.15 s, total: 2min 30s
Wall time: 39.8 s

Let’s parse the samples into an ArviZ InferenceData object.

idata = az.from_numpyro(
    mcmc,
    coords={
        "var_1": data.columns,
        "var_2": data.columns,
        "lag": range(1, n_lags + 1),
        "future": range(data.shape[0], data.shape[0] + future),
    },
    dims={
        "constant": ["var_1"],
        "sigma": ["var_1"],
        "l_omega": ["var_1", "var_2"],
        "phi": ["lag", "var_1", "var_2"],
        "y_future": ["future", "var_1"],
    },
)

We can now visualize the traces.

axes = az.plot_trace(
    data=idata,
    var_names=["~y_future"],
    compact=True,
    backend_kwargs={"figsize": (12, 7), "layout": "constrained"},
)
plt.gcf().suptitle("Trace", fontsize=18, fontweight="bold");

Overall, the chains seem to converge well.
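As a complementary numerical check (a small sketch using standard ArviZ summaries, not part of the original notebook), we can look at the \(\hat{R}\) values and effective sample sizes of the main parameters:

# Convergence diagnostics for the model parameters (forecast sites excluded).
summary = az.summary(idata, var_names=["constant", "sigma", "phi"], round_to=3)
summary[["mean", "sd", "ess_bulk", "ess_tail", "r_hat"]]

Values of \(\hat{R}\) close to \(1\) and large effective sample sizes support what we see in the trace plots.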

Parameter Comparison

We can manually inspect some parameters' posterior means to check whether they match the reference values from the statsmodels results.

(
    idata["posterior"]["phi"]
    .mean(dim=["chain", "draw"])
    .sel(var_1="realgdp")
    .to_dataframe()
    .sort_index()
)
              var_1       phi
lag var_2
1   realcons  realgdp  0.668054
    realgdp   realgdp -0.272790
    realinv   realgdp  0.032299
2   realcons  realgdp  0.297236
    realgdp   realgdp -0.000703
    realinv   realgdp -0.006114

var_results.params["realgdp"].to_frame()
               realgdp
const         0.001527
L1.realgdp   -0.279435
L1.realcons   0.675016
L1.realinv    0.033219
L2.realgdp    0.008221
L2.realcons   0.290458
L2.realinv   -0.007321

We do see the values are very close 🚀!

Similarly, we can look into the correlation matrix:

l_omega_mean = idata["posterior"]["l_omega"].mean(dim=["chain", "draw"])

corr_mean = l_omega_mean.to_numpy() @ l_omega_mean.to_numpy().T

corr_mean
array([[1.        , 0.59747654, 0.745766  ],
       [0.59747654, 0.9966769 , 0.12399177],
       [0.745766  , 0.12399177, 0.9963025 ]], dtype=float32)

This is very close to the residual correlation matrix from the statsmodels summary above.
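Above we averaged the Cholesky factor first and then formed the product. As a cross-check (a small sketch, not part of the original comparison), we can instead build the correlation matrix for each posterior sample and average afterwards; the result should be essentially the same:

# Correlation matrix per posterior sample: L_omega @ L_omega.T, then average
# over chains and draws.
l_omega_samples = idata["posterior"]["l_omega"].to_numpy()  # (chain, draw, vars, vars)
corr_samples = np.einsum("cdij,cdkj->cdik", l_omega_samples, l_omega_samples)
corr_samples.mean(axis=(0, 1))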

Finally, we can compare all the \(\Phi\) parameters with the estimated posterior distributions.

# Reference values from statsmodels VAR(2) results
# Simplified construction using list comprehension and f-strings
lags = range(2, 0, -1)
variables = data.columns

ref_vals_phi = {
    "phi": [
        {
            "lag": lag,
            "var_1": var_1,
            "var_2": var_2,
            "ref_val": var_results.params[var_1][f"L{lag}.{var_2}"],
        }
        for lag in lags
        for var_1 in variables
        for var_2 in variables
    ]
}

axes = az.plot_posterior(
    idata,
    var_names=["phi"],
    ref_val=ref_vals_phi,
    figsize=(18, 20),
)
fig = axes[0][0].figure
fig.suptitle(
    r"Posterior Distribution of $\Phi$ Parameters",
    fontsize=18,
    fontweight="bold",
    y=1.02,
);

All of the estimated parameters are very close to the reference values!

Forecast

Next, we turn to forecasting. We overlay the NumPyro posterior predictive intervals for y_future on top of the statsmodels forecast plot.

fig = var_results.plot_forecast(steps=future, alpha=0.05, plot_stderr=True)
axes = fig.get_axes()

t_future = idata["posterior"]["y_future"].coords["future"]

for var_idx, var in enumerate(data.columns):
    ax = axes[var_idx]
    for i, hdi_prob in enumerate([0.94, 0.5]):
        az.plot_hdi(
            t_future,
            idata["posterior"]["y_future"].sel(var_1=var),
            color="C0",
            hdi_prob=hdi_prob,
            fill_kwargs={"alpha": 0.2 + 0.2 * i},
            ax=ax,
        )

    ax.plot(
        t_future,
        idata["posterior"]["y_future"].sel(var_1=var).mean(dim=["chain", "draw"]),
        color="C0",
    )

fig.suptitle(
    "Forecast VAR(2) Model\nStatsmodels vs Numpyro",
    fontsize=18,
    fontweight="bold",
    y=1.06,
);
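As a complementary numerical check (a short sketch reusing the objects defined above), we can compare the posterior-mean forecast against the statsmodels point forecast:

# statsmodels point forecast, starting from the last n_lags observations
forecast_sm = var_results.forecast(data.values[-n_lags:], steps=future)

# NumPyro posterior-mean forecast, shape (future, n_vars)
forecast_np = idata["posterior"]["y_future"].mean(dim=["chain", "draw"]).to_numpy()

# These should be close, but not identical: one is an OLS point forecast,
# the other a posterior mean.
np.abs(forecast_sm - forecast_np).max()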

Impulse Response Functions (IRFs)

After fitting a VAR model, we often want to understand how the system responds to shocks. This is where Impulse Response Functions (IRFs) come in.

The Intuition

Imagine you have three economic variables: GDP, consumption, and investment. Now suppose there’s an unexpected shock to GDP (e.g., a sudden policy change). An IRF answers questions like:

  • How does GDP itself respond over time?
  • How does consumption react to this GDP shock?
  • What happens to investment in subsequent periods?

IRFs trace out the dynamic response of each variable to a one-time shock in another variable, holding all else constant.

From VAR to MA Representation

Recall our \(\text{VAR}(p)\) model:

\[Y_t = c + \sum_{j=1}^{p} \Phi_j Y_{t-j} + \varepsilon_t\]

This model tells us how current values depend on past values. But to compute IRFs, we need the Moving Average (\(\text{MA}(\infty)\)) representation, which expresses current values in terms of current and past shocks:

\[Y_t = \mu + \sum_{i=0}^{\infty} \Psi_i \varepsilon_{t-i}\]

where:

  • \(\Psi_0 = I\) (identity matrix - a shock has immediate unit effect on itself)
  • \(\Psi_i\) are the MA coefficient matrices that tell us the response at time \(i\) to a shock at time 0
  • These \(\Psi_i\) matrices are the IRFs!

The Recursive Algorithm

A key result is that the \(\Psi_i\) matrices can be computed recursively from the VAR coefficients \(\Phi_j\):

\[\Psi_0 = I\]

\[\Psi_i = \sum_{j=1}^{\min(i, p)} \Psi_{i-j} \Phi_j \quad \text{for } i \geq 1\]

This means:

  • At time 0: The response is just the identity (shock = response)
  • At time 1: \(\Psi_1 = \Psi_0 \Phi_1 = \Phi_1\) (first-order effects)
  • At time 2: \(\Psi_2 = \Psi_1 \Phi_1 + \Psi_0 \Phi_2\) (effects compound!)
  • And so on…

Each \(\Psi_i[k, j]\) tells us:

“What is the response of variable k at time i to a unit shock in variable j at time 0?”

Deriving the Recursive Formula

The recursive formula for computing IRFs comes from substituting the VAR representation into itself. Here’s the intuition:

Starting with the VAR(p) model:

\[Y_t = c + \Phi_1 Y_{t-1} + \Phi_2 Y_{t-2} + \cdots + \Phi_p Y_{t-p} + \varepsilon_t\]

We want to express \(Y_t\) purely in terms of shocks \(\varepsilon_t, \varepsilon_{t-1}, \varepsilon_{t-2}, \ldots\)

The key insight: Keep substituting past values with their own VAR equations!

  • At \(t\): \(Y_t = c + \Phi_1 Y_{t-1} + \Phi_2 Y_{t-2} + \cdots + \varepsilon_t\)
  • For \(Y_{t-1}\): substitute its VAR equation, which brings in \(\varepsilon_{t-1}\)
  • For \(Y_{t-2}\): substitute its VAR equation, which brings in \(\varepsilon_{t-2}\)
  • Continue infinitely…

After all substitutions and collecting terms by shock timing, we get:

\[Y_t = \mu + \varepsilon_t + \Psi_1 \varepsilon_{t-1} + \Psi_2 \varepsilon_{t-2} + \cdots\]

The recursion emerges because:

  • The coefficient on \(\varepsilon_{t-i}\) (which is \(\Psi_i\)) depends on how \(Y_{t-1}, Y_{t-2}, \ldots\) responded to that same shock in earlier periods
  • Specifically: \(\Psi_i\) accumulates contributions from \(\Phi_1 \Psi_{i-1}\) (via \(Y_{t-1}\)), \(\Phi_2 \Psi_{i-2}\) (via \(Y_{t-2}\)), etc.

This gives us the recursive relationship above.

The \(\min(i, p)\) appears because we only have \(p\) lags in the VAR: there is no \(\Phi_j\) for \(j > p\).
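Before turning to the JAX implementation, here is a minimal pure-NumPy sketch of the recursion. The helper name ma_rep_naive is ours (it is not part of the original code or of statsmodels); it is only meant to make the indexing explicit:

def ma_rep_naive(phi: np.ndarray, n_steps: int) -> np.ndarray:
    """Plain-loop version of Psi_i = sum_{j=1}^{min(i, p)} Psi_{i-j} @ Phi_j."""
    p, k, _ = phi.shape
    psi = [np.eye(k)]  # Psi_0 = I
    for i in range(1, n_steps):
        # Only lags j = 1, ..., min(i, p) contribute.
        psi.append(sum(psi[i - j] @ phi[j - 1] for j in range(1, min(i, p) + 1)))
    return np.stack(psi)

Up to numerical precision, ma_rep_naive(np.asarray(var_results.coefs), 10) should reproduce var_results.ma_rep(maxn=9). The JAX implementation below performs the same computation with lax.scan, so that it can be JIT-compiled and vmapped over posterior samples.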

Implementation Strategy

We implement the impulse response computation in the compute_irf function below, which carries out the recursive algorithm efficiently using:

  1. JAX’s lax.scan for fast, functional iteration
  2. JIT compilation for maximum speed
  3. Vectorization to compute all responses simultaneously across all posterior samples

Let’s see how it works! 👇

Impulse Response Functions

def compute_irf(
    phi: Float[Array, "*sample n_lags n_vars n_vars"],
    n_steps: int,
    shock_size: float = 1.0,
) -> Float[Array, "*sample n_steps n_vars n_vars"]:
    """
    Compute MA(∞) representation of VAR(p) process (non-orthogonalized IRF).

    Implements the recursive algorithm using jax.lax.scan:

    Ψ_0 = I

    Ψ_i = sum_{j=1}^{min(i,p)} Ψ_{i-j} @ Φ_j for i >= 1

    Parameters
    ----------
    phi : array of shape (n_lags, n_vars, n_vars)
        VAR coefficient matrices Φ_j. phi[j-1] corresponds to Φ_j.
    n_steps : int
        Number of MA coefficient matrices to compute.
    shock_size : float, default=1.0
        Scaling factor for identity matrix at t=0.

    Returns
    -------
    psi : array of shape (n_steps, n_vars, n_vars)
        MA representation (IRF matrices). psi[i, :, j] is the response of all variables
        at time i to a unit shock to variable j at time 0.
    """

    n_lags, n_vars, _ = phi.shape

    def scan_fn(carry: Array, i: Array) -> tuple[Array, Array]:
        """
        Compute Ψ_i from previous MA matrices.

        carry: Array of shape (n_lags, n_vars, n_vars) containing the last n_lags MA
            matrices. carry[0] is Ψ_{i-1}, carry[1] is Ψ_{i-2}, ..., carry[n_lags - 1]
            is Ψ_{i-n_lags}
        i: current time step
        """
        # Compute Ψ_i = sum_{j=1}^{min(i,p)} Ψ_{i-j} @ Φ_j
        # We need to handle the case where i < n_lags (early steps)

        # carry[0] is Ψ_{i-1}, carry[1] is Ψ_{i-2}, etc.
        # phi[0] is Φ_1 (lag 1), phi[1] is Φ_2 (lag 2), etc.

        # For each lag j from 1 to min(i, n_lags):
        #   Ψ_{i-j} is carry[j - 1]
        #   Φ_j is phi[j - 1]

        # Create a mask to only sum over valid lags (up to min(i, n_lags))
        valid_lags = jnp.arange(n_lags) < jnp.minimum(i, n_lags)

        # Compute contributions: Ψ_{i-j} @ Φ_j for each j
        # carry[j] @ phi[j] for j in range(n_lags)
        contributions = jnp.einsum("jkl,jlm->jkm", carry, phi)

        # Mask invalid contributions and sum
        psi_i = jnp.sum(contributions * valid_lags[:, None, None], axis=0)

        # Update carry: shift everything by 1 and add new Ψ_i at the front
        new_carry = jnp.concatenate([psi_i[None, :, :], carry[:-1]], axis=0)

        return new_carry, psi_i

    # Initialize carry with zeros and set Ψ_0 = I at the front
    psi_0 = shock_size * jnp.eye(n_vars)
    init_carry = jnp.concatenate(
        [psi_0[None, :, :], jnp.zeros((n_lags - 1, n_vars, n_vars))], axis=0
    )

    # Run scan for steps 1 to n_steps - 1
    if n_steps == 1:
        return psi_0[None, :, :]

    time_steps = jnp.arange(1, n_steps)
    _, psis_rest = lax.scan(scan_fn, init_carry, time_steps)

    # Concatenate Ψ_0 with the rest
    return jnp.concatenate([psi_0[None, :, :], psis_rest], axis=0)


compute_irf_jit = jit(
    compute_irf,
    static_argnames=["n_steps", "shock_size"],  # For this example is enough
)

Let’s verify that this implementation matches the statsmodels results.

n_irf_steps = 10

phi_sm = jnp.array(var_results.coefs)

# Get statsmodels IRF (ma_rep)
irf_sm = var_results.ma_rep(maxn=n_irf_steps - 1)

# Compute IRF with scan-based function
irf_jax_scan = compute_irf_jit(phi_sm, n_steps=n_irf_steps)

# Check results match
assert jnp.allclose(irf_jax_scan, irf_sm)

Great! We have a working implementation of the impulse response function! We can now compute the IRFs for all posterior samples. To do this we need to vectorize the compute_irf_jit function.

# Get all posterior samples (flatten chain and draw dimensions)
phi_samples = jnp.array(idata["posterior"]["phi"].stack(sample=["chain", "draw"]))  # noqa PD013
# Transpose to get samples as first dimension
phi_samples = jnp.transpose(
    phi_samples, (3, 0, 1, 2)
)  # Shape: (n_samples, n_lags, n_vars, n_vars)


# Create a vmapped version that computes IRF for each posterior sample
# vmap over the first axis (samples)
compute_irf_vmap = vmap(compute_irf_jit, in_axes=(0, None, None))

# Compute IRFs for all posterior samples
irf_samples = compute_irf_vmap(phi_samples, n_irf_steps, 1.0)

# Create an xarray DataArray to store the IRFs
irf_samples_xr = xr.DataArray(
    data=jnp.expand_dims(irf_samples, axis=0),
    dims=("chain", "draw", "step", "var_1", "var_2"),
    coords={
        "chain": np.arange(1),
        "draw": np.arange(irf_samples.shape[0]),
        "step": np.arange(n_irf_steps),
        "var_1": data.columns,
        "var_2": data.columns,
    },
)

We can now plot the IRFs generated from the posterior samples and compare them against the IRFs from the statsmodels model.

var_irf = var_results.irf(n_irf_steps)

fig, axes = plt.subplots(
    nrows=len(data.columns),
    ncols=len(data.columns),
    figsize=(15, 12),
    sharex=True,
    sharey=False,
    layout="constrained",
)

for i, var_1 in enumerate(data.columns):
    for j, var_2 in enumerate(data.columns):
        ax = axes[i, j]
        for k, hdi_prob in enumerate([0.94, 0.5]):
            az.plot_hdi(
                range(n_irf_steps),
                irf_samples_xr.sel(var_1=var_1, var_2=var_2),
                hdi_prob=hdi_prob,
                color="C0",
                smooth=False,
                fill_kwargs={
                    "alpha": 0.2 + 0.3 * k,
                    "label": f"{hdi_prob: .0%} HDI",
                },
                ax=ax,
            )
        ax.plot(
            range(n_irf_steps),
            irf_samples_xr.sel(var_1=var_1, var_2=var_2)
            .mean(dim=("chain", "draw"))
            .to_numpy()
            .flatten(),
            color="C0",
            label="Posterior Mean",
        )
        # Slice so the statsmodels line covers the same steps as the posterior IRFs.
        ax.plot(var_irf.irfs[:n_irf_steps, i, j], c="C1", label="statsmodels")
        ax.axhline(0, color="gray", linestyle="--")
        ax.legend()

fig.suptitle("Impulse Response Functions", fontsize=21, fontweight="bold", y=1.04);

We get the same results as the statsmodels model! 🎉

Observe that statsmodels generates similar plots (with similar uncertainty bounds).

fig = var_irf.plot(orth=False)

Orthogonalized vs Non-Orthogonalized IRFs

An important distinction in VAR analysis is between orthogonalized and non-orthogonalized impulse response functions. Let’s understand what this means and why it matters.

The Problem: Contemporaneous Correlation

In our VAR model, the error terms \(\varepsilon_t\) are typically correlated across variables. For example:

  • A shock to GDP and a shock to consumption might happen simultaneously
  • The error covariance matrix \(\Sigma = E[\varepsilon_t \varepsilon_t^{T}]\) is usually not diagonal

This creates a problem for interpretation: when we “shock” one variable, we’re implicitly shocking correlated variables too!

Non-Orthogonalized IRFs (What We Computed)

Our compute_irf function computes non-orthogonalized IRFs using the formula:

\[\Psi_i = \sum_{j=1}^{\min(i,p)} \Psi_{i-j} \Phi_j\]

Interpretation: These IRFs answer: “What happens if variable j experiences a one-unit shock, given the historical correlation structure of shocks?”

Advantages:

  • Simple to compute (just the VAR coefficients)
  • No additional identifying assumptions needed
  • Useful for forecasting and variance decomposition

Disadvantages:

  • Hard to interpret as “pure” shocks due to contemporaneous correlation
  • Because the errors are correlated, a unit shock to a single variable is not a realistic experiment

Orthogonalized IRFs (Cholesky Decomposition)

To get orthogonalized IRFs, we use the Cholesky decomposition of the error covariance matrix \(\Sigma = LL^{T}\), where \(L\) is lower triangular. Then:

\[\Psi_i^{\text{orth}} = \Psi_i \cdot L\]

In statsmodels, the orthogonalized responses are available from the IRAnalysis object returned by irf(), for example via its plot(orth=True) method.

Interpretation: These IRFs answer: “What happens if we apply a one-standard-deviation orthogonal shock to variable j, holding other orthogonal shocks constant?”

Advantages:

  • Shocks are uncorrelated by construction
  • Easier to interpret as “structural” shocks
  • Standard in macroeconomics literature

Disadvantages:

  • Depends on variable ordering (reordering the variables changes the Cholesky factor)
  • Requires identifying assumptions (ordering = causal structure)
  • The first variable is assumed to affect all others contemporaneously, but not vice versa
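Although we focus on non-orthogonalized IRFs in this notebook, obtaining orthogonalized ones from the posterior only requires the Cholesky factor of \(\Sigma\), which we already sample through sigma and l_omega. Here is a minimal sketch (not part of the original code) that reuses irf_samples from above:

# Lower Cholesky factor of Sigma per posterior sample:
# L = diag(sigma) @ L_omega, the same construction as l_sigma in the model.
sigma_samples = idata["posterior"]["sigma"].to_numpy()      # (chain, draw, vars)
l_omega_samples = idata["posterior"]["l_omega"].to_numpy()  # (chain, draw, vars, vars)
l_sigma_samples = np.einsum("cdi,cdij->cdij", sigma_samples, l_omega_samples)

# Flatten chains and draws. NOTE: this assumes the C-order reshape matches the
# chain/draw stacking used to build phi_samples (and hence irf_samples) above.
l_sigma_flat = l_sigma_samples.reshape(-1, *l_sigma_samples.shape[-2:])

# Orthogonalized IRFs: Psi_i^orth = Psi_i @ L, per step and per sample.
irf_orth_samples = np.einsum("stij,sjk->stik", np.asarray(irf_samples), l_sigma_flat)

Column \(k\) of irf_orth_samples[s, t] is then the response at step \(t\) to a one-standard-deviation orthogonalized shock in variable \(k\). These draws can be plotted exactly like the non-orthogonalized IRFs above and compared against statsmodels via var_irf.plot(orth=True).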

Summary

In this notebook, we successfully implemented a Bayesian Vector Autoregressive (VAR) model using NumPyro and validated it against the established statsmodels implementation. Our key achievements include:

  1. Model Specification & Inference: We built a flexible \(\text{VAR}(p)\) model in NumPyro using jax.lax.scan for efficient time series dynamics, leveraging Bayesian priors (Normal for the coefficients, LKJ for the correlation structure) and MCMC sampling via NUTS. This can serve as a building block for more complex models, for example by adding covariates or additional likelihood components.

  2. Forecasting: The model generates multi-step-ahead forecasts with full posterior uncertainty quantification. Our NumPyro forecasts closely match the statsmodels point predictions. We also show how to generate credible intervals for the forecasts.

  3. Impulse Response Functions: We implemented the recursive \(\text{MA}(\infty)\) representation algorithm to compute IRFs using JAX’s functional programming tools (lax.scan, vmap, jit). When evaluated at the statsmodels coefficients, the resulting IRFs match the statsmodels output exactly, validating our implementation.

Key Advantages of the NumPyro Approach:

  • Full Bayesian inference with uncertainty quantification for all parameters and predictions
  • Scalable computation through JAX’s JIT compilation and vectorization
  • Flexible specification allowing easy extensions (e.g., hierarchical priors, time-varying coefficients)

This implementation provides a solid foundation for more advanced Bayesian time series modeling in economic and financial applications, among other domains.