In this notebook we describe how to use blackjax’s pathfinder implementation to do inference with a numpyro model.
I am simply putting together some pieces from the following resources (all strongly recommended reading):
References:
- Blackjax docs: Use with Numpyro models
- Blackjax Sampling Book: Pathfinder
- Numpyro Issue #1485
- PyMC Experimental - Pathfinder
- Pathfinder: Parallel quasi-Newton variational inference
What and Why Pathfinder?
From the paper’s abstract:
- What?
We propose Pathfinder, a variational method for approximately sampling from differentiable log densities. Starting from a random initialization, Pathfinder locates normal approximations to the target density along a quasi-Newton optimization path, with local covariance estimated using the inverse Hessian estimates produced by the optimizer. Pathfinder returns draws from the approximation with the lowest estimated Kullback-Leibler (KL) divergence to the true posterior.
- Why?
Compared to ADVI and short dynamic HMC runs, Pathfinder requires one to two orders of magnitude fewer log density and gradient evaluations, with greater reductions for more challenging posteriors.
Prepare Notebook
import arviz as az
import blackjax
import jax
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoMultivariateNormal
from numpyro.infer.util import initialize_model, Predictive
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"
numpyro.set_host_device_count(n=4)
rng_key = random.PRNGKey(seed=42)
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
Generate Data
We generate some data from a simple linear regression model.
def generate_data(rng_key, a, b, sigma, n):
    x = random.normal(rng_key, (n,))
    rng_key, rng_subkey = random.split(rng_key)
    epsilon = sigma * random.normal(rng_subkey, (n,))
    y = a + b * x + epsilon
    return x, y
# true parameters
a = 1.0
b = 2.0
sigma = 0.5
n = 100
# generate data
rng_key, rng_subkey = random.split(rng_key)
x, y = generate_data(rng_key, a, b, sigma, n)
# plot data
fig, ax = plt.subplots(figsize=(8, 7))
ax.plot(x, y, "o", c="C0", label="data")
ax.axline((0, a), slope=b, color="C1", label="true mean")
ax.legend(loc="upper left")
ax.set(xlabel="x", ylabel="y", title="Raw Data")
Model Specification
We define a simple linear regression model in numpyro.
def model(x, y=None):
    a = numpyro.sample("a", dist.Normal(0.0, 2.0))
    b = numpyro.sample("b", dist.HalfNormal(2.0))
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))
    mean = numpyro.deterministic("mu", a + b * x)
    with numpyro.plate("data", len(x)):
        numpyro.sample("likelihood", dist.Normal(mean, sigma), obs=y)
numpyro.render_model(
    model=model,
    model_args=(x, y),
    render_distributions=True,
    render_params=True,
)
Pathfinder Sampler
The key function is initialize_model from numpyro. It allows us to compute the log-density, which is required by blackjax’s pathfinder implementation. In addition, we get a function to transform samples from the unconstrained space (where the optimization happens) back to the constrained space.
rng_key, rng_subkey = random.split(rng_key)
param_info, potential_fn, postprocess_fn, *_ = initialize_model(
    rng_subkey,
    model,
    model_args=(x, y),
    dynamic_args=True,  # <- this is important!
)
# get log-density from the potential function
def logdensity_fn(position):
    func = potential_fn(x, y)
    return -func(position)
# get initial position
initial_position = param_info.z
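As a quick sanity check (my addition, not part of the referenced resources), we can verify that the log-density evaluates to a finite scalar at this initial position:
# sanity check (not in the original resources): the log-density should be a
# finite scalar when evaluated at the unconstrained initial position
print(initial_position)  # dict of unconstrained initial values for "a", "b", "sigma"
print(logdensity_fn(initial_position))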
We can now use blackjax.vi.pathfinder.approximate to run the variational inference algorithm.
# run pathfinder
rng_key, rng_subkey = random.split(rng_key)
pathfinder_state, _ = blackjax.vi.pathfinder.approximate(
    rng_key=rng_subkey,
    logdensity_fn=logdensity_fn,
    initial_position=initial_position,
    num_samples=10_000,
    ftol=1e-4,
)
# sample from the posterior
rng_key, rng_subkey = random.split(rng_key)
posterior_samples, _ = blackjax.vi.pathfinder.sample(
    rng_key=rng_subkey,
    state=pathfinder_state,
    num_samples=5_000,
)
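The draws returned by blackjax live in the unconstrained space and keep the same pytree structure as initial_position, that is, a dictionary mapping site names to arrays. A quick way to inspect them (again, my own addition):
# inspect the unconstrained draws (assumed sites: "a", "b", "sigma")
for name, draws in posterior_samples.items():
    print(name, draws.shape)  # each array should have shape (5000,)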
# convert to arviz
idata_pathfinder = az.from_dict(
    posterior={
        k: np.expand_dims(a=np.asarray(v), axis=0) for k, v in posterior_samples.items()
    },
)
Visualize Results
az.summary(data=idata_pathfinder, round_to=3)
|       | mean   | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-------|--------|-------|--------|---------|-----------|---------|----------|----------|-------|
| a     | 1.043  | 0.051 | 0.951  | 1.139   | 0.001     | 0.001   | 5047.694 | 4780.373 | NaN   |
| b     | 0.730  | 0.028 | 0.679  | 0.785   | 0.000     | 0.000   | 4998.218 | 4873.062 | NaN   |
| sigma | -0.657 | 0.069 | -0.785 | -0.527  | 0.001     | 0.001   | 5056.785 | 5101.670 | NaN   |
axes = az.plot_trace(
    data=idata_pathfinder,
    compact=True,
    backend_kwargs={"layout": "constrained"},
)
plt.gcf().suptitle(
    t="Pathfinder Trace - Transformed Space", fontsize=18, fontweight="bold"
)
Note that the value for a is close to the true value of 1.0. However, the values for b and sigma do not match the true values of 2.0 and 0.5 respectively. The reason is that we are using the prior distributions dist.HalfNormal and dist.Exponential for these parameters. Since these distributions have positive support, the sampler transforms these parameters to the unconstrained space using a bijective transformation, and the optimization (and hence the samples above) happens in that space. To compare the results in the original space, we need to apply the inverse transformation.
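To make this concrete: numpyro maps positive-support parameters to the unconstrained space with a log transform, so exponentiating the unconstrained draws of b and sigma should roughly recover the original scale. A minimal sketch (my own, assuming the default bijection for the positive constraint):
import jax.numpy as jnp

# undo the log transform by hand for the positive parameters; the general
# approach is postprocess_fn, shown in the next section
print(jnp.exp(posterior_samples["b"]).mean())      # roughly 2.0
print(jnp.exp(posterior_samples["sigma"]).mean())  # roughly 0.5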
Transform Samples
We can use the postprocess_fn function returned by initialize_model to transform the samples from the unconstrained space to the constrained space:
posterior_samples_transformed = jax.vmap(postprocess_fn(x, y))(posterior_samples)
rng_key, rng_subkey = random.split(rng_key)
posterior_predictive_samples_transformed = Predictive(
    model=model, posterior_samples=posterior_samples_transformed
)(rng_subkey, x)
idata_pathfinder_transformed = az.from_dict(
    posterior={
        k: np.expand_dims(a=np.asarray(v), axis=0)
        for k, v in posterior_samples_transformed.items()
    },
    posterior_predictive={
        k: np.expand_dims(a=np.asarray(v), axis=0)
        for k, v in posterior_predictive_samples_transformed.items()
    },
)
axes = az.plot_trace(
    data=idata_pathfinder_transformed,
    var_names=["~mu"],
    compact=True,
    lines=[
        ("a", {}, a),
        ("b", {}, b),
        ("sigma", {}, sigma),
    ],
    backend_kwargs={"layout": "constrained"},
)
plt.gcf().suptitle(
    t="Pathfinder Trace - Original Space", fontsize=18, fontweight="bold"
)
fig, ax = plt.subplots(figsize=(8, 7))
ax.plot(x, y, "o", c="C0", label="data")
ax.axline((0, a), slope=b, color="C1", label="true mean")
az.plot_hdi(
    x=x,
    y=idata_pathfinder_transformed["posterior_predictive"]["mu"],
    color="C2",
    fill_kwargs={"alpha": 0.5, "label": r"mu posterior ($94\%$ HDI)"},
    ax=ax,
)
az.plot_hdi(
    x=x,
    y=idata_pathfinder_transformed["posterior_predictive"]["likelihood"],
    color="C2",
    fill_kwargs={"alpha": 0.3, "label": r"posterior predictive ($94\%$ HDI)"},
    ax=ax,
)
ax.legend(loc="upper left")
ax.set(xlabel="x", ylabel="y", title="Pathfinder Posterior Predictive")
Appendix: SVI
Here we compare against the stochastic variational inference (SVI) algorithm implemented in numpyro.
guide = AutoMultivariateNormal(model=model)
optimizer = numpyro.optim.Adam(step_size=0.01)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
rng_key, rng_subkey = random.split(key=rng_key)
num_steps = 1_000  # number of optimization steps for SVI
svi_result = svi.run(rng_subkey, num_steps, x, y)
fig, ax = plt.subplots(figsize=(9, 6))
ax.plot(svi_result.losses)
ax.set_title("ELBO loss", fontsize=18, fontweight="bold")
params = svi_result.params
# get posterior samples (parameters)
predictive = Predictive(model=guide, params=params, num_samples=4_000)
rng_key, rng_subkey = random.split(key=rng_key)
posterior_samples = predictive(rng_subkey, x, y)
# get posterior predictive (deterministics and likelihood)
predictive = Predictive(model=model, guide=guide, params=params, num_samples=4_000)
rng_key, rng_subkey = random.split(key=rng_key)
samples = predictive(rng_subkey, x, y)
idata_svi = az.from_dict(
    posterior={
        k: np.expand_dims(a=np.asarray(v), axis=0) for k, v in posterior_samples.items()
    },
)
az.summary(data=idata_svi, var_names=["a", "b", "sigma"], round_to=3)
|       | mean  | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-------|-------|-------|--------|---------|-----------|---------|----------|----------|-------|
| a     | 1.034 | 0.052 | 0.935  | 1.132   | 0.001     | 0.001   | 3158.147 | 3521.286 | NaN   |
| b     | 2.062 | 0.063 | 1.940  | 2.178   | 0.001     | 0.001   | 4074.610 | 3881.514 | NaN   |
| sigma | 0.516 | 0.035 | 0.455  | 0.583   | 0.001     | 0.000   | 3630.886 | 3446.498 | NaN   |
axes = az.plot_trace(
    data=idata_svi,
    var_names=["a", "b", "sigma"],
    compact=True,
    lines=[
        ("a", {}, a),
        ("b", {}, b),
        ("sigma", {}, sigma),
    ],
    backend_kwargs={"layout": "constrained"},
)
plt.gcf().suptitle(t="SVI Trace", fontsize=18, fontweight="bold")
axes = az.plot_forest(
    data=[idata_pathfinder_transformed, idata_svi],
    model_names=["Pathfinder", "SVI"],
    var_names=["a", "b", "sigma"],
    combined=True,
    figsize=(9, 6),
)