# Score the in-sample and forecast predictive draws with CRPS; both scores are
# reported in the plot title further down.
crps_train = eval_crps(in_sample_pp, train_data)
crps_test = eval_crps(forecast_samples, test_data)


def _to_idata(draws, observed, times):
    """Wrap predictive draws and observations indexed by ``times`` in an InferenceData.

    ``draws[..., 0]`` becomes the ``obs`` posterior-predictive variable. A
    leading length-1 axis is prepended so it fills ArviZ's chain slot.
    NOTE(review): index 0 of the last axis presumably selects the first output
    component of the sampler — confirm.

    ``observed`` is attached as ``observed_data`` and ``times`` as
    ``constant_data`` (``time``). Every variable shares the ``t`` dimension,
    whose coordinate values are ``times``.
    """
    payload = {
        "posterior_predictive": {"obs": np.asarray(draws[..., 0])[None]},
        "observed_data": {"obs": np.asarray(observed)},
        "constant_data": {"time": np.asarray(times)},
    }
    return az.from_dict(
        payload,
        coords={"t": np.asarray(times)},
        dims={"obs": ["t"], "time": ["t"]},
    )


def _plot_hdi_bands(idata, color, **extra):
    """Draw the 50% and 94% HDI bands of ``obs`` against ``time`` in ``color``.

    Smoothing is off, and the observed scatter and point-estimate line are
    hidden. The observed series is drawn separately as one continuous line.
    ``extra`` is forwarded unchanged to ``az.plot_lm`` (e.g. ``figure_kwargs``
    or ``plot_collection``). Returns whatever ``az.plot_lm`` returns.
    """
    return az.plot_lm(
        idata,
        y="obs",
        x="time",
        ci_kind="hdi",
        ci_prob=(0.5, 0.94),
        smooth=False,
        visuals={"ci_band": {"color": color}, "observed_scatter": False, "pe_line": False},
        **extra,
    )


idata_in_sample = _to_idata(in_sample_pp, y_train, t_train)
idata_forecast = _to_idata(forecast_samples, y_test, t_test)

pc = _plot_hdi_bands(idata_in_sample, "C0", figure_kwargs={"figsize": (12, 7)})
# Materialise the in-sample band artists now. The forecast call below draws into
# the same plot collection, and its bands are read back from this same
# ``pc.viz`` key afterwards.
in_sample_bands = pc.viz["ci_band"]["time"]
band_in_94, band_in_50 = (in_sample_bands.sel(prob=p).item() for p in (0.94, 0.5))

_plot_hdi_bands(idata_forecast, "C1", plot_collection=pc)
forecast_bands = pc.viz["ci_band"]["time"]
band_fc_94, band_fc_50 = (forecast_bands.sel(prob=p).item() for p in (0.94, 0.5))

# Both plot_lm calls share one figure; everything else is added to its first axes.
ax = pc.viz["figure"].item().axes[0]
for band, band_label in (
    (band_in_94, r"in-sample $94\%$ HDI"),
    (band_in_50, r"in-sample $50\%$ HDI"),
    (band_fc_94, r"forecast $94\%$ HDI"),
    (band_fc_50, r"forecast $50\%$ HDI"),
):
    band.set_label(band_label)

# Full observed series across train and test, plus a marker at the first test time.
observed_line = ax.plot(np.asarray(t), np.asarray(y), color="black", lw=1, label="observed")[0]
split_line = ax.axvline(float(t_test[0]), color="gray", linestyle="--", label="train/test split")

# Legend sits below the axes in three columns so it does not cover the bands.
legend_handles = [band_in_94, band_in_50, band_fc_94, band_fc_50, observed_line, split_line]
ax.legend(handles=legend_handles, loc="upper center", bbox_to_anchor=(0.5, -0.1), ncol=3)

ax.set_title(
    f"Exponential smoothing forecast (train CRPS: {crps_train:.3f}, test CRPS: {crps_test:.3f})"
)
ax.set_xlabel("time")
ax.set_ylabel("y")
plt.show()