----------------------------------------------------------------------
This is the API documentation for the numpyro_forecast library.
----------------------------------------------------------------------

## Forecasters

High-level interfaces for fitting and forecasting.

Forecaster(rng_key: jax.Array, model: collections.abc.Callable[..., None], data: jax.Array, covariates: jax.Array, *, guide: numpyro.infer.autoguide.AutoGuide | None = None, optim: numpyro.optim._NumPyroOptim | None = None, num_steps: int = 1001, num_particles: int = 1, progress_bar: bool = False) -> None

Fit a forecasting model with stochastic variational inference.

Parameters
----------
rng_key
    PRNG key for inference.
model
    The forecasting model to fit (OOP instance or functional model).
data
    In-sample data with time at axis ``-2``.
covariates
    Covariates with time at axis ``-2`` and the same duration as ``data``.
guide
    Variational guide; defaults to ``AutoNormal(model)``.
optim
    NumPyro optimizer; defaults to ``Adam(0.01)``.
num_steps
    Number of SVI steps.
num_particles
    Number of ELBO particles.
progress_bar
    Whether to display the SVI progress bar.

HMCForecaster(rng_key: jax.Array, model: collections.abc.Callable[..., None], data: jax.Array, covariates: jax.Array, *, num_warmup: int = 1000, num_samples: int = 1000, num_chains: int = 1, progress_bar: bool = False) -> None

Fit a forecasting model with NUTS (Hamiltonian Monte Carlo).

Parameters
----------
rng_key
    PRNG key for inference.
model
    The forecasting model to fit (OOP instance or functional model).
data
    In-sample data with time at axis ``-2``.
covariates
    Covariates with time at axis ``-2`` and the same duration as ``data``.
num_warmup
    Number of warmup steps.
num_samples
    Number of posterior samples.
num_chains
    Number of MCMC chains.
progress_bar
    Whether to display the MCMC progress bar.

## Models

Building forecasting models (object-oriented and functional).

ForecastingModel() -> None

Abstract base class for forecasting models.
Subclasses implement :meth:`model`, which must call :meth:`predict` exactly once. The instance itself is the (pure) NumPyro model function with signature ``model_instance(covariates, data=None)``: the forecast horizon is inferred from the shapes (``future = covariates.shape[-2] - data.shape[-2]``).

This is the object-oriented façade over the functional API: :meth:`time_series` and :meth:`predict` delegate to the free functions in :mod:`numpyro_forecast.functional`, passing the current :class:`~numpyro_forecast.functional.Horizon`.

forecasting_model(model_fn: collections.abc.Callable[[numpyro_forecast.functional.Horizon, jax.Array], None]) -> collections.abc.Callable[..., None]

Build a NumPyro model from a functional model body.

The functional analogue of subclassing :class:`~numpyro_forecast.forecaster.ForecastingModel`. ``model_fn`` is a pure function ``(Horizon, covariates) -> None`` that calls :func:`time_series` and :func:`predict`; this wraps it into the standard NumPyro model callable ``(covariates, data=None)``, deriving the :class:`Horizon` from the shapes.

Parameters
----------
model_fn
    The model body. It receives the per-call :class:`Horizon` (use ``h.zero_data`` for the Pyro-style ``zero_data``) and the covariates with time at axis ``-2``.

Returns
-------
ForecastModel
    A callable ``(covariates, data=None) -> None`` usable with ``SVI``, ``MCMC``, ``Predictive``, :func:`fit_svi`, :func:`fit_mcmc`, and the OOP forecaster classes.

## Functional core

Pure functional primitives for the train/forecast split.

Horizon(data: jax.Array | None, t_obs: int, future: int, duration: int) -> None

The train/forecast split for a single model call.

Replaces the mutable ``self._*`` state of the OOP base class with an immutable value derived from the covariate and data shapes via :meth:`from_data`. The functional primitives (:func:`time_series`, :func:`predict`) take it as their first argument.
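A minimal sketch of how such a value can be derived from array shapes, following the rule stated for ``ForecastingModel`` (``future = covariates.shape[-2] - data.shape[-2]``) and the output shape documented for ``zero_data_like``. It uses NumPy and a hypothetical ``HorizonSketch`` class with a ``from_shapes`` constructor; it is not the library's ``Horizon`` or ``from_data``, and treating a missing ``data`` array as "no forecast horizon" is an assumption.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass(frozen=True)
class HorizonSketch:
    """Hypothetical immutable train/forecast split (illustration only)."""

    data: Optional[np.ndarray]
    t_obs: int
    future: int
    duration: int

    @classmethod
    def from_shapes(cls, covariates, data=None):
        # Time lives at axis -2 for both covariates and data.
        duration = covariates.shape[-2]
        # Assumption: without data (prior sampling) the whole span is in-sample.
        t_obs = duration if data is None else data.shape[-2]
        future = duration - t_obs
        if future < 0:
            raise ValueError("data is longer than covariates along axis -2")
        return cls(data=data, t_obs=t_obs, future=future, duration=duration)

    @property
    def zero_data(self):
        # Zeros shaped like the data but extended to the full horizon, exposing
        # shape/dtype without leaking observed values (cf. zero_data_like).
        if self.data is None:
            raise ValueError("zero_data needs observed data to infer its shape")
        shape = self.data.shape[:-2] + (self.duration, self.data.shape[-1])
        return np.zeros(shape, dtype=self.data.dtype)


covariates = np.zeros((10, 3))  # duration 10, 3 covariates
data = np.ones((7, 1))          # t = 7 observed steps, 1 observation dim
h = HorizonSketch.from_shapes(covariates, data)
print(h.t_obs, h.future, h.duration)  # 7 3 10
print(h.zero_data.shape)              # (10, 1)
```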
Attributes
----------
data
    Observed in-sample data with time at axis ``-2`` (``None`` during pure prior sampling).
t_obs
    Number of observed (in-sample) time steps ``t``.
future
    Number of forecast time steps ``f`` (``0`` while training).
duration
    Total horizon length ``t + future`` (in time steps).

time_series(h: numpyro_forecast.functional.Horizon, name: str, dist_fn: collections.abc.Callable[[], numpyro.distributions.distribution.Distribution], *, reparam: numpyro.infer.reparam.Reparam | None = None) -> jax.Array

Sample a time-varying latent over the full horizon.

The in-sample portion is sampled under ``plate("time", t)`` with the fixed site ``name``; when forecasting, the horizon portion is sampled under a separate site ``f"{name}_future"`` and concatenated. The separate site keeps the guide shape fixed and lets ``Predictive`` draw the forecast suffix from the prior.

Parameters
----------
h
    The horizon for the current model call (see :class:`Horizon`).
name
    Base sample-site name for the in-sample latent.
dist_fn
    Zero-argument callable returning the per-step prior distribution.
reparam
    Optional reparameterization (e.g. ``LocScaleReparam``) applied to both the in-sample and forecast sites.

Returns
-------
Array
    The latent over the full horizon with time at axis ``-2``.

predict(h: numpyro_forecast.functional.Horizon, noise_dist: numpyro.distributions.distribution.Distribution, prediction: jax.Array) -> None

Register the observation/forecast sites for the model.

``noise_dist`` is a zero-centered observation noise distribution and ``prediction`` is the deterministic mean over the full horizon. While training, the residual is observed; while forecasting, the in-sample prefix is observed and the forecast suffix is sampled and exposed as the ``"forecast"`` deterministic site.

Parameters
----------
h
    The horizon for the current model call (see :class:`Horizon`).
noise_dist
    Zero-centered observation noise (e.g. ``Normal(0, sigma)``).
prediction
    Deterministic mean with time at axis ``-2``, shape ``(*batch, duration, obs)``.

Raises
------
RuntimeError
    If forecasting (``future > 0``) but no observed data is available.

fit_svi(rng_key: jax.Array, model: collections.abc.Callable[..., None], data: jax.Array, covariates: jax.Array, *, guide: numpyro.infer.autoguide.AutoGuide | None = None, optim: numpyro.optim._NumPyroOptim | None = None, num_steps: int = 1001, num_particles: int = 1, progress_bar: bool = False) -> numpyro_forecast.functional.SVIFit

Fit a forecasting model with stochastic variational inference.

Parameters
----------
rng_key
    PRNG key for inference.
model
    The forecasting model callable (OOP instance or functional model).
data
    In-sample data with time at axis ``-2``.
covariates
    Covariates with time at axis ``-2`` and the same duration as ``data``.
guide
    Variational guide; defaults to ``AutoNormal(model)``.
optim
    NumPyro optimizer; defaults to ``Adam(0.01)``.
num_steps
    Number of SVI steps.
num_particles
    Number of ELBO particles.
progress_bar
    Whether to display the SVI progress bar.

Returns
-------
SVIFit
    The fitted guide, variational parameters, and loss history.

Raises
------
ValueError
    If ``data`` and ``covariates`` have different durations.

draw_posterior(rng_key: jax.Array, fit: object, num_samples: int) -> dict[str, jax.Array]

Draw ``num_samples`` posterior samples of the latent sites from a fit.

Dispatches on the fit type (e.g. :class:`SVIFit`, :class:`MCMCFit`). The returned dict has the sample axis leading and is ready to pass to :func:`forecast` or NumPyro's ``Predictive``.

Parameters
----------
rng_key
    PRNG key.
fit
    A fit result produced by :func:`fit_svi` or :func:`fit_mcmc`.
num_samples
    Number of posterior draws.

Returns
-------
dict[str, Array]
    Posterior samples of the latent sites, sample axis leading.

Raises
------
NotImplementedError
    If ``fit`` is of an unsupported type.
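The Notes for ``draw_posterior`` describe two regimes for an ``MCMCFit``: thinning on an evenly spaced grid when ``num_samples`` fits in the chain, and resampling with replacement otherwise. The NumPy sketch below illustrates that index selection. ``select_draw_indices`` and ``thin_samples`` are hypothetical helpers, the integer grid ``floor(i * n / k)`` is an assumed reading of "evenly spaced", and sharing one index set across all sites is a design choice of this sketch (it keeps each joint draw intact), not a statement about the library's code.

```python
import numpy as np


def select_draw_indices(rng, num_draws, num_samples):
    """Hypothetical helper: choose which of ``num_draws`` chain draws to keep."""
    if num_samples <= num_draws:
        # Evenly spaced integer grid. Consecutive grid points are at least
        # num_draws / num_samples >= 1 apart, so the floors never repeat.
        return (np.arange(num_samples) * num_draws) // num_samples
    # More samples requested than the chain holds: resample with replacement.
    return rng.choice(num_draws, size=num_samples, replace=True)


def thin_samples(rng, samples, num_samples):
    """Apply one shared index set to every site so joint draws stay aligned."""
    num_draws = next(iter(samples.values())).shape[0]
    idx = select_draw_indices(rng, num_draws, num_samples)
    return {name: value[idx] for name, value in samples.items()}


rng = np.random.default_rng(0)
samples = {"level": np.arange(10.0), "sigma": np.arange(10.0) * 0.1}
thinned = thin_samples(rng, samples, 4)
print(thinned["level"])  # [0. 2. 5. 7.]
```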
Notes
-----
For an :class:`MCMCFit`, when ``num_samples`` does not exceed the number of draws in the chain, the draws are thinned on an evenly spaced grid (no duplicates); only when more samples are requested than the chain holds are they resampled with replacement. For an :class:`SVIFit`, the draws are sampled afresh from the fitted guide.

fit_mcmc(rng_key: jax.Array, model: collections.abc.Callable[..., None], data: jax.Array, covariates: jax.Array, *, num_warmup: int = 1000, num_samples: int = 1000, num_chains: int = 1, progress_bar: bool = False) -> numpyro_forecast.functional.MCMCFit

Fit a forecasting model with NUTS (Hamiltonian Monte Carlo).

Parameters
----------
rng_key
    PRNG key for inference.
model
    The forecasting model callable (OOP instance or functional model).
data
    In-sample data with time at axis ``-2``.
covariates
    Covariates with time at axis ``-2`` and the same duration as ``data``.
num_warmup
    Number of warmup steps.
num_samples
    Number of posterior samples.
num_chains
    Number of MCMC chains.
progress_bar
    Whether to display the MCMC progress bar.

Returns
-------
MCMCFit
    The posterior samples.

Raises
------
ValueError
    If ``data`` and ``covariates`` have different durations.

forecast(rng_key: jax.Array, model: collections.abc.Callable[..., None], posterior: dict[str, jax.Array], data: jax.Array, covariates: jax.Array, *, batch_size: int | None = None) -> jaxtyping.Float[Array, 'sample *batch future obs']

Sample forecasts for the steps in ``[t, duration)`` from a posterior.

Runs ``Predictive`` with the full-horizon ``covariates`` and the in-sample ``data``: the in-sample latent sites are taken from ``posterior`` while the ``_future`` suffix is drawn from the prior, and the ``"forecast"`` site is returned. The number of forecast samples equals the leading (sample) axis of ``posterior`` (see :func:`draw_posterior`).

Parameters
----------
rng_key
    PRNG key.
model
    The forecasting model callable (the same one that produced ``posterior``).
posterior
    Posterior samples of the latent sites, sample axis leading.
data
    Observed data with time at axis ``-2`` and length ``t``.
covariates
    Covariates with time at axis ``-2`` and length ``duration > t``.
batch_size
    Optional chunk size for sampling (caps peak memory).

Returns
-------
Float[Array, " sample *batch future obs"]
    Forecast samples over the ``future = duration - t`` horizon.

Raises
------
ValueError
    If ``covariates`` does not extend beyond ``data`` along the time axis.

predict_in_sample(rng_key: jax.Array, model: collections.abc.Callable[..., None], posterior: dict[str, jax.Array], covariates: jax.Array, *, batch_size: int | None = None) -> jaxtyping.Float[Array, 'sample *batch time obs']

Sample the in-sample posterior predictive of the ``obs`` site.

Runs ``Predictive`` with the in-sample ``covariates`` and the supplied posterior latent draws. Unlike :func:`forecast`, there is no forecast horizon: ``covariates`` span only the observed window, so the model's ``obs`` site is sampled at every step. The number of predictive samples equals the leading (sample) axis of ``posterior`` (see :func:`draw_posterior`).

Parameters
----------
rng_key
    PRNG key.
model
    The forecasting model callable (the same one that produced ``posterior``).
posterior
    Posterior samples of the latent sites, sample axis leading.
covariates
    Covariates with time at axis ``-2`` spanning the observed window. Its time length must match the data the ``posterior`` was fit on, since the in-sample latent sites are sized to that window.
batch_size
    Optional chunk size for sampling (caps peak memory).

Returns
-------
Float[Array, " sample *batch time obs"]
    In-sample posterior-predictive draws of the ``obs`` site.

SVIFit(guide: numpyro.infer.autoguide.AutoGuide, params: dict[str, jax.Array], losses: jax.Array) -> None

The result of fitting a forecasting model with SVI.

Attributes
----------
guide
    The fitted variational guide.
params
    The learned variational parameters.
losses
    The ELBO loss per SVI step (shape ``(num_steps,)``).

MCMCFit(samples: dict[str, jax.Array]) -> None

The result of fitting a forecasting model with MCMC (NUTS).

Attributes
----------
samples
    The posterior samples of the latent sites, sample axis leading.

## Backtesting & evaluation

Rolling-window backtesting and forecast metrics.

backtest(rng_key: jax.Array, data: jax.Array, covariates: jax.Array, model_fn: collections.abc.Callable[[], collections.abc.Callable[..., None]], *, forecaster_fn: collections.abc.Callable[..., '_BaseForecaster'] = , metrics: collections.abc.Mapping[str, collections.abc.Callable[[jax.Array, jax.Array], float]] | None = None, transform: collections.abc.Callable[[jax.Array, jax.Array], tuple[jax.Array, jax.Array]] | None = None, train_window: int | None = None, min_train_window: int = 1, test_window: int | None = None, min_test_window: int = 1, stride: int = 1, num_samples: int = 100, batch_size: int | None = None, forecaster_options: collections.abc.Mapping[str, Any] | collections.abc.Callable[..., collections.abc.Mapping[str, Any]] | None = None, eval_train: bool = False, keep_predictions: bool = False) -> list[numpyro_forecast.evaluate.BacktestResult]

Backtest a forecasting model on a moving window of ``(train, test)`` data.

Parameters
----------
rng_key
    Base PRNG key (used for every window, matching Pyro).
data
    Dataset with time at axis ``-2``.
covariates
    Covariates with time at axis ``-2`` (same duration as ``data``).
model_fn
    Factory returning a fresh :class:`ForecastingModel` per window.
forecaster_fn
    Factory returning a fitted forecaster (defaults to :class:`Forecaster`).
metrics
    Mapping of metric name to function; defaults to :data:`DEFAULT_METRICS`. Each function takes ``(pred, truth)`` and returns a float; bind any metric-specific parameters with :func:`functools.partial`, e.g. ``{**DEFAULT_METRICS, "coverage": partial(eval_coverage, alpha=0.8)}``.
transform
    Optional ``(pred, truth) -> (pred, truth)`` applied before metrics.
train_window
    Training window size; if ``None``, the window expands from the start.
min_train_window
    Minimum training window size when ``train_window`` is ``None``.
test_window
    Test window size; if ``None``, forecasts run to the end of the data.
min_test_window
    Minimum test window size when ``test_window`` is ``None``.
stride
    Step between successive train/test splits.
num_samples
    Number of forecast samples per window.
batch_size
    Optional forecast-sampling chunk size.
forecaster_options
    Options dict passed to ``forecaster_fn``, or a callable ``(t0, t1, t2) -> dict`` returning per-window options.
eval_train
    If ``True``, also score the in-sample posterior predictive over each training window with the same ``metrics`` and store the scores in ``BacktestResult.train_metrics``. Requires a forecaster exposing ``predict_in_sample`` (the built-in :class:`Forecaster` and :class:`HMCForecaster` do).
keep_predictions
    If ``True``, store each window's out-of-sample forecast samples (after ``transform``) on ``BacktestResult.prediction``. Defaults to ``False`` to avoid retaining large Monte Carlo arrays.

Returns
-------
list[BacktestResult]
    One result per backtest window.

BacktestResult(t0: int, t1: int, t2: int, num_samples: int, train_walltime: float, test_walltime: float, metrics: dict[str, float], params: dict[str, float] = , train_metrics: dict[str, float] = , prediction: jax.Array | None = None) -> None

Per-window result of a :func:`backtest` run.

Attributes
----------
t0, t1, t2
    Train-begin, train/test split, and test-end time indices.
num_samples
    Number of forecast samples drawn.
train_walltime, test_walltime
    Wall-clock seconds for fitting and forecasting.
metrics
    Mapping of metric name to value for the window.
params
    Mapping of scalar parameter name to value (when available).
train_metrics
    Mapping of metric name to in-sample value for the window. Empty unless ``backtest`` was called with ``eval_train=True``.
prediction
    Out-of-sample forecast samples for the window (sample axis leading) when ``backtest`` was called with ``keep_predictions=True``; otherwise ``None``.

evaluate_forecast(pred: jaxtyping.Float[Array, 'sample *batch'], truth: jaxtyping.Float[Array, '*batch'], *, metrics: collections.abc.Mapping[str, collections.abc.Callable[[jax.Array, jax.Array], float]] | None = None) -> dict[str, float]

Evaluate forecast samples against ground truth for several metrics at once.

A one-call convenience that applies each metric in ``metrics`` to the same forecast samples and ground truth. It is the one-shot counterpart to :func:`backtest` and is also used internally by :func:`backtest` to score each rolling window.

Metric-specific parameters live with the metric in the ``metrics`` mapping, not on this function. To tune a metric, bind its keyword with :func:`functools.partial`; for example, to score coverage at the 80% level::

    from functools import partial

    metrics = {**DEFAULT_METRICS, "coverage": partial(eval_coverage, alpha=0.8)}
    evaluate_forecast(pred, truth, metrics=metrics)

Parameters
----------
pred
    Forecast samples with the sample axis first, shape ``(sample, *batch)``.
truth
    Ground-truth values with shape ``(*batch)``.
metrics
    Mapping of metric name to function; when ``None``, defaults to :data:`DEFAULT_METRICS` (``mae``, ``rmse``, ``crps`` and ``coverage``). Each function takes ``(pred, truth)`` and returns a float; bind any extra parameters with :func:`functools.partial` (see above).

Returns
-------
dict[str, float]
    Each metric name mapped to its value.

eval_crps(pred: jax.Array, truth: jax.Array) -> float

Empirical CRPS averaged over all data elements.

Parameters
----------
pred
    Forecast samples with the sample axis first.
truth
    Ground-truth values (matching ``pred`` without the sample axis).

Returns
-------
float
    The mean empirical CRPS.

eval_mae(pred: jax.Array, truth: jax.Array) -> float

Mean absolute error using the forecast sample median as the point estimate.
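``eval_mae`` and ``eval_rmse`` (whose entry follows) differ only in the point estimate taken over the sample axis (median versus mean) and the loss applied to it. A NumPy sketch, assuming both average over every data element as ``eval_crps`` does; ``mae_sketch`` and ``rmse_sketch`` are hypothetical names, not library functions. Note that ``np.median`` averages the two middle samples when the sample count is even, which may differ from the library's median convention.

```python
import numpy as np


def mae_sketch(pred, truth):
    # Point estimate: median over the leading sample axis.
    point = np.median(pred, axis=0)
    return float(np.mean(np.abs(point - truth)))


def rmse_sketch(pred, truth):
    # Point estimate: mean over the leading sample axis.
    point = np.mean(pred, axis=0)
    return float(np.sqrt(np.mean((point - truth) ** 2)))


# Three forecast samples over two time steps and one observation dim.
pred = np.array([[[1.0], [2.0]],
                 [[2.0], [4.0]],
                 [[6.0], [9.0]]])  # shape (sample=3, time=2, obs=1)
truth = np.array([[2.0], [3.0]])   # shape (time=2, obs=1)
print(mae_sketch(pred, truth))     # 0.5
print(rmse_sketch(pred, truth))    # sqrt(2.5), about 1.5811
```

In this example the outlying third sample shifts the mean-based point estimate (3 and 5) but not the median-based one (2 and 4).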
Parameters
----------
pred
    Forecast samples with the sample axis first.
truth
    Ground-truth values (matching ``pred`` without the sample axis).

Returns
-------
float
    The mean absolute error.

eval_rmse(pred: jax.Array, truth: jax.Array) -> float

Root mean squared error using the forecast sample mean as the point estimate.

Parameters
----------
pred
    Forecast samples with the sample axis first.
truth
    Ground-truth values (matching ``pred`` without the sample axis).

Returns
-------
float
    The root mean squared error.

eval_coverage(pred: jax.Array, truth: jax.Array, *, alpha: float = 0.9) -> float

Empirical coverage of the central ``alpha`` prediction interval.

The central ``alpha`` interval is bounded by the ``(1 - alpha) / 2`` and ``1 - (1 - alpha) / 2`` quantiles of the forecast samples; the metric is the fraction of ground-truth values that fall inside it. A well-calibrated forecast has coverage close to ``alpha``.

Parameters
----------
pred
    Forecast samples with the sample axis first.
truth
    Ground-truth values (matching ``pred`` without the sample axis).
alpha
    Nominal interval level in ``(0, 1)``; when omitted, uses the module default ``_DEFAULT_COVERAGE_ALPHA`` (``0.9``, as shown in the signature).

Returns
-------
float
    The fraction of ground-truth values inside the central ``alpha`` interval.

crps_empirical(pred: jaxtyping.Float[Array, 'sample *batch'], truth: jaxtyping.Float[Array, '*batch']) -> jaxtyping.Float[Array, '*batch']

Compute the empirical Continuous Ranked Probability Score (CRPS).

The CRPS generalises the mean absolute error to probabilistic forecasts and is computed elementwise as

.. math::

    \mathrm{CRPS}(F, y) = \mathbb{E}|X - y| - \tfrac{1}{2}\,\mathbb{E}|X - X'|,

where :math:`X, X'` are independent draws from the forecast distribution :math:`F`. The expectations are estimated from the forecast ``sample`` axis using the sorted-sample :math:`O(n \log n)` identity.

Parameters
----------
pred
    Forecast samples with the sample axis first, shape ``(sample, *batch)``.
truth
    Ground-truth values with shape ``(*batch)`` (broadcastable to ``pred``).

Returns
-------
Float[Array, "*batch"]
    Elementwise CRPS, one value per ``batch`` location.

References
----------
Tilmann Gneiting, Adrian E. Raftery (2007). "Strictly Proper Scoring Rules, Prediction, and Estimation". *Journal of the American Statistical Association*.

## Utilities

Array helpers and feature builders.

fourier_features(duration: int, period: float, num_terms: int) -> jaxtyping.Float[Array, 'duration two_num_terms']

Build a Fourier seasonality design matrix.

Parameters
----------
duration
    Number of time steps.
period
    Seasonal period (in time steps).
num_terms
    Number of harmonics; the output has ``2 * num_terms`` columns (sine then cosine).

Returns
-------
Float[Array, "duration two_num_terms"]
    The design matrix of shape ``(duration, 2 * num_terms)``.

periodic_repeat(x: jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex, duration: int, *, axis: int = -1) -> jax.Array

Tile a seasonal pattern to cover ``duration`` time steps.

Parameters
----------
x
    Seasonal pattern; the repeated axis has length equal to the period. Accepts any array-like (e.g. a raw ``numpyro.sample`` draw).
duration
    Target length along ``axis``.
axis
    Axis to repeat along (defaults to ``-1``).

Returns
-------
Array
    ``x`` periodically repeated to length ``duration`` along ``axis``.

zero_data_like(data: jax.Array, covariates: jax.Array) -> jax.Array

Return zeros shaped like ``data`` but extended to the covariate duration.

Mirrors Pyro's ``zero_data``: it exposes the shape/dtype of the data over the full forecast horizon without leaking observed values into the model. The functional API exposes the equivalent value as :attr:`numpyro_forecast.functional.Horizon.zero_data`.

Parameters
----------
data
    Observed data with time at axis ``-2``, shape ``(*batch, t, obs)``.
covariates
    Covariates with time at axis ``-2``, shape ``(*batch, duration, cov)``.
Returns
-------
Array
    Zeros of shape ``(*batch, duration, obs)``.

concat_future(prefix: jax.Array, suffix: jax.Array, *, axis: int = -2) -> jax.Array

Concatenate in-sample and forecast-horizon arrays along the time axis.

Parameters
----------
prefix
    In-sample array.
suffix
    Forecast-horizon array (same shape as ``prefix`` except along ``axis``).
axis
    Time axis to concatenate along (defaults to ``-2``).

Returns
-------
Array
    The concatenation of ``prefix`` and ``suffix`` along ``axis``.

shift_loc(noise_dist: numpyro.distributions.distribution.Distribution, loc: jax.Array) -> numpyro.distributions.distribution.Distribution

Re-center a zero-centered noise distribution at ``loc``.

This converts Pyro's ``obs = data - prediction`` idiom into an additive shift of the observation distribution's location.

Parameters
----------
noise_dist
    A zero-centered location-family distribution.
loc
    The deterministic mean to add to the distribution's location.

Returns
-------
dist.Distribution
    A distribution centered at ``loc``.

Raises
------
NotImplementedError
    If ``noise_dist`` is of an unsupported type.

slice_time(noise_dist: numpyro.distributions.distribution.Distribution, index: slice) -> numpyro.distributions.distribution.Distribution

Slice an elementwise distribution along the time axis ``-2``.

The default implementation handles distributions with empty ``event_shape`` whose ``batch_shape`` ends with ``(time, obs)`` (e.g. ``Normal``, ``StudentT``) by slicing each broadcast parameter.

Parameters
----------
noise_dist
    The distribution to slice.
index
    A ``slice`` applied to the time axis ``-2`` of the batch shape.

Returns
-------
dist.Distribution
    The same distribution family restricted to the selected time steps.

Raises
------
NotImplementedError
    If the distribution has a non-empty event shape.
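A minimal NumPy sketch of what ``shift_loc`` and ``slice_time`` do for an elementwise Normal, assuming a location-scale family stored as broadcastable ``loc``/``scale`` arrays. ``NormalSketch``, ``shift_loc_sketch`` and ``slice_time_sketch`` are hypothetical stand-ins, not the library's dispatch (which operates on ``numpyro.distributions`` objects). The final ``[t:]`` slice with ``t = 2`` is also the independent-noise case of ``prefix_condition``, whose entry states that it reduces to exactly that slice.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class NormalSketch:
    """Hypothetical elementwise Normal (empty event shape), for illustration."""

    loc: np.ndarray
    scale: np.ndarray

    @property
    def batch_shape(self):
        return np.broadcast_shapes(np.shape(self.loc), np.shape(self.scale))


def shift_loc_sketch(noise, loc):
    # Adding a deterministic mean to a location family only moves its location.
    return NormalSketch(loc=noise.loc + loc, scale=noise.scale)


def slice_time_sketch(noise, index):
    # Broadcast every parameter to the full batch shape (..., time, obs) first,
    # so that scalar parameters such as a shared scale can be sliced too.
    shape = noise.batch_shape
    if len(shape) < 2:
        raise ValueError("expected a batch shape ending in (time, obs)")
    loc = np.broadcast_to(noise.loc, shape)[..., index, :]
    scale = np.broadcast_to(noise.scale, shape)[..., index, :]
    return NormalSketch(loc=loc, scale=scale)


noise = NormalSketch(loc=np.zeros(()), scale=np.full((), 0.5))  # zero-centered
prediction = np.arange(6.0).reshape(3, 2)                        # (duration=3, obs=2)
obs_dist = shift_loc_sketch(noise, prediction)
future_dist = slice_time_sketch(obs_dist, slice(2, None))        # keep steps [t:], t = 2
print(future_dist.loc)          # [[4. 5.]]
print(future_dist.scale.shape)  # (1, 2)
```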
prefix_condition(noise_dist: numpyro.distributions.distribution.Distribution, data: jax.Array) -> numpyro.distributions.distribution.Distribution

Condition a ``(t+f)``-length distribution on a ``t``-length data prefix.

For independent-over-time noise (the default), the conditional reduces to the forecast-horizon marginal, i.e. a time slice ``[t:]``. Only independent families are currently supported; correlated families (e.g. ``MultivariateNormal``) would need a registered dispatch implementing a genuine Gaussian conditional, which is not yet provided.

Parameters
----------
noise_dist
    The observation distribution over the full horizon ``(*batch, t+f, obs)``.
data
    The observed prefix with shape ``(*batch, t, obs)``.

Returns
-------
dist.Distribution
    The forecast-horizon distribution over ``(*batch, f, obs)``.

## Datasets

Example datasets used in the tutorials.

load_bart_weekly() -> jaxtyping.Float[Array, 'weeks 1']

Load total weekly BART ridership (log scale) for the univariate example.

Hourly counts are summed over all origin-destination pairs, aggregated into non-overlapping weeks, and log-transformed.

Returns
-------
Float[Array, " weeks 1"]
    Log weekly totals with time at axis ``-2`` and a single observation dim.

load_bart_hierarchical(train_days: int = 90, test_weeks: int = 2) -> tuple[jaxtyping.Float[Array, 'origin time destin'], int, list[str]]

Load the windowed hierarchical BART panel for the hierarchical example.

The counts are ``log1p``-transformed and transposed to the ``(origin, time, destin)`` convention, then restricted to a ``train_days`` training window followed by a ``test_weeks`` test window.

Parameters
----------
train_days
    Number of training days (24 hours each).
test_weeks
    Number of test weeks (``24 * 7`` hours each).

Returns
-------
y : Float[Array, " origin time destin"]
    Log counts over the train+test window with time at axis ``-2``.
split : int
    Index along the time axis separating train from test.
stations : list[str]
    Station names.
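The window sizes are counted in hours (24 per training day, ``24 * 7`` per test week), so the defaults cover ``90 * 24 + 2 * 168 = 2496`` hours, and if ``y`` begins with the training window the split falls at index ``2160``. A NumPy sketch of that windowing, assuming the raw counts are laid out as ``(time, origin, destin)`` and that the window is taken from the most recent hours; ``window_sketch`` is hypothetical and not the loader's implementation.

```python
import numpy as np


def window_sketch(counts, train_days=90, test_weeks=2):
    """Hypothetical: keep the most recent train+test hours of hourly counts.

    Assumes ``counts`` has shape (time, origin, destin).
    """
    train_hours = 24 * train_days       # days of 24 hours
    test_hours = 24 * 7 * test_weeks    # weeks of 24 * 7 hours
    total = train_hours + test_hours
    num_hours = counts.shape[0]
    start = num_hours - total
    if start < 0:
        # A negative start index would wrap around and silently yield a shorter window.
        raise ValueError(f"window of {total} hours exceeds the {num_hours} available")
    window = np.log1p(counts[start:])    # log1p-transform the kept hours
    y = np.transpose(window, (1, 0, 2))  # (time, origin, destin) -> (origin, time, destin)
    return y, train_hours                # split index along the time axis


counts = np.ones((300, 2, 2))  # hypothetical hourly counts for 2 stations
y, split = window_sketch(counts, train_days=3, test_weeks=1)
print(y.shape, split)  # (2, 240, 2) 72
```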
Raises
------
ValueError
    If the requested ``train_days`` + ``test_weeks`` window exceeds the available history (which would otherwise wrap a negative slice index).

load_victoria_electricity() -> tuple[jaxtyping.Float[Array, 'time 1'], jaxtyping.Float[Array, 'time']]

Load hourly Victoria (Australia) electricity demand and temperature.

The series covers the first eight weeks of 2014, sampled hourly, and is taken from the Victoria electricity demand dataset used in the TensorFlow Probability structural-time-series case study and in Hyndman and Athanasopoulos' *Forecasting: Principles and Practice*. The original half-hourly data is downsampled to hourly by taking every other step. The values are bundled as a small CSV next to this module.

Returns
-------
demand : Float[Array, " time 1"]
    Hourly electricity demand (GW) with time at axis ``-2`` and a single observation dimension.
temperature : Float[Array, " time"]
    Hourly temperature (degrees Celsius), aligned with ``demand``.

bart_available() -> bool

Return whether the BART dataset can be loaded (i.e. whether the download succeeds).

Returns
-------
bool
    ``True`` if :func:`load_bart_od` loads without error.
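The preprocessing described for ``load_bart_weekly`` (sum over origin-destination pairs, non-overlapping weekly totals, log) and for ``load_victoria_electricity`` (keep every other half-hourly step) can be sketched in NumPy. The ``(hours, origin, destin)`` layout of the raw counts and dropping a trailing partial week are assumptions; both helpers are hypothetical and not the loaders' code.

```python
import numpy as np

HOURS_PER_WEEK = 24 * 7


def weekly_log_totals(counts):
    """Hypothetical: (hours, origin, destin) counts -> (weeks, 1) log weekly totals."""
    hourly_total = counts.sum(axis=(1, 2))  # sum over all origin-destination pairs
    num_weeks = hourly_total.shape[0] // HOURS_PER_WEEK
    # Assumption: any trailing partial week is dropped.
    full = hourly_total[: num_weeks * HOURS_PER_WEEK]
    weekly = full.reshape(num_weeks, HOURS_PER_WEEK).sum(axis=1)
    return np.log(weekly)[:, None]  # time at axis -2, one observation dim


def half_hourly_to_hourly(x):
    """Keep every other step along the leading time axis."""
    return x[::2]


counts = np.ones((2 * HOURS_PER_WEEK + 5, 3, 3))  # 2 full weeks plus 5 spare hours
print(weekly_log_totals(counts).shape)  # (2, 1)

demand = np.arange(8 * 7 * 48, dtype=float)[:, None]  # eight weeks, half-hourly
print(half_hourly_to_hourly(demand).shape)  # (1344, 1)
```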