In this post we want to revisit a simple Bayesian inference example worked out in this blog post. This time we want to use TensorFlow Probability (TFP) instead of PyMC3.
References:
Statistical Rethinking is an amazing reference for Bayesian analysis. It also has a sequence of online lectures freely available on YouTube.
An introduction to probabilistic programming, now available in TensorFlow Probability
There are many examples in TensorFlow Probability's GitHub repository. I am following the case study Bayesian Switchpoint Analysis for this example.
Prepare Notebook
import numpy as np
import scipy.stats as ss
import pandas as pd
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions
# Data Viz.
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import seaborn as sns
sns.set_style(
    style='darkgrid',
    rc={'axes.facecolor': '.9', 'grid.color': '.8'}
)
sns.set_palette(palette='deep')
sns_c = sns.color_palette(palette='deep')
%matplotlib inline
from pandas.plotting import register_matplotlib_converters
register_matplotlib_converters()
plt.rcParams['figure.figsize'] = [8, 6]
plt.rcParams['figure.dpi'] = 100
# Get TensorFlow version.
print(f'TensorFlow version: {tf.__version__}')
print(f'TensorFlow Probability version: {tfp.__version__}')
TensorFlow version: 2.2.0
TensorFlow Probability version: 0.10.0
Generate Data
We generate sample data from a Poisson distribution using TensorFlow Probability Distributions (see here for an introduction to this module). Recall,
\[ y \sim \text{Poisson}(\lambda) \quad \text{means} \quad P(y=k) = \frac{\lambda^k e^{-\lambda}}{k!} \quad \text{for} \quad \lambda > 0, k\in \mathbb{N}_{\geq 0} \]
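As a quick sanity check of this formula (a minimal sketch, separate from the data-generation code below), we can compare it against tfd.Poisson:

from scipy.special import factorial

lam = 2.0
k = np.arange(5, dtype=np.float32)
# PMF from the formula above.
pmf_formula = lam**k * np.exp(-lam) / factorial(k)
# PMF as computed by TFP.
pmf_tfp = tfd.Poisson(rate=lam).prob(k).numpy()
print(np.allclose(pmf_formula, pmf_tfp, atol=1e-6))  # True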
tf.random.set_seed(seed=42)
# Number of samples.
n = 100
# True rate parameter.
rate_true = 2.0
# Define Poisson distribution with the true rate parameter.
poisson_true = tfd.Poisson(rate=rate_true)
# Generate samples.
poisson_samples = poisson_true.sample(sample_shape=n)
poisson_samples
<tf.Tensor: shape=(100,), dtype=float32, numpy=
array([3., 3., 3., 1., 2., 1., 2., 0., 0., 2., 2., 1., 2., 2., 0., 3., 1.,
4., 1., 1., 2., 0., 3., 4., 1., 3., 1., 2., 2., 0., 0., 2., 1., 1.,
3., 5., 4., 3., 1., 4., 2., 2., 0., 2., 2., 4., 4., 2., 0., 4., 2.,
4., 1., 1., 2., 1., 1., 1., 4., 1., 3., 3., 1., 3., 0., 1., 1., 3.,
5., 1., 2., 2., 4., 0., 2., 3., 1., 3., 2., 4., 2., 2., 3., 2., 1.,
3., 1., 1., 0., 1., 2., 1., 2., 3., 4., 3., 2., 1., 0., 1.],
dtype=float32)>
Let us plot the sample distribution.
y_range, idy, c = tf.unique_with_counts(poisson_samples)
fig, ax = plt.subplots()
sns.barplot(x=y_range.numpy(), y=c.numpy(), color=sns_c[0], edgecolor=sns_c[0], ax=ax)
ax.xaxis.set_major_formatter(FormatStrFormatter('%.0f'))
ax.set(title=f'Poisson Samples Distribution (num_samples = {n}, rate_true = {rate_true})');
Set Prior Distribution
We use a gamma distribution as prior because it is the conjugate prior for the Poisson likelihood, so the posterior can be obtained analytically.
Remark: In most real applications an analytical solution is hopeless.
\[ y \sim \text{Poisson}(\lambda) \quad \text{with} \quad \lambda \sim \Gamma(a, b) \]
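Concretely, conjugacy here means that the posterior is again a gamma distribution whose parameters can be written down directly from the data:

\[ \lambda \mid y_1, \ldots, y_n \sim \Gamma\left(a + \sum_{i=1}^{n} y_i, \; b + n\right) \]

We will use this closed-form posterior later as a sanity check for the MCMC results.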
# Define parameters for the prior distribution.
a = 4.5
b = 2
# Define prior distribution.
gamma_prior = tfd.Gamma(concentration=a, rate=b)
Let us generate some samples from this prior distribution:
# Generate samples.
gamma_prior_samples = gamma_prior.sample(sample_shape=int(1e4))
# Plot.
fig, ax = plt.subplots(figsize=(7, 5))
# Domain to plot.
x = np.linspace(start=0, stop=10, num=100)
# Plot samples distribution.
sns.distplot(
    a=gamma_prior_samples,
    color=sns_c[1],
    kde=False,
    norm_hist=True,
    label='samples (1e4)',
    ax=ax
)
# Plot density function of the gamma distribution.
sns.lineplot(
    x=x,
    y=ss.gamma.pdf(x, a=a, scale=1/b),
    color=sns_c[1],
    label='gamma_density',
    ax=ax
)
# Some Stats.
sample_mean = tf.reduce_mean(gamma_prior_samples)
sample_median = tfp.stats.percentile(x=gamma_prior_samples, q=50)
ax.axvline(
    x=sample_mean,
    color=sns_c[1],
    linestyle='--',
    label=f'sample mean = {sample_mean: 0.2f}'
)
ax.axvline(
    x=sample_median,
    color=sns_c[1],
    linestyle=':',
    label=f'sample median = {sample_median: 0.2f}'
)
ax.legend()
ax.set(title=f'Prior Gamma Distribution (a={a}, b={b})');
Prior Predictive Sampling
Before fitting the model to the data, let us sample from Poisson distributions based on these prior samples of the rate parameter lambda.
y_prior_pred = tfd.Poisson(rate=gamma_prior_samples).sample(1)
y_prior_pred = tf.reshape(y_prior_pred, [-1])
y_range_prior, idy_prior, c_prior = tf.unique_with_counts(y_prior_pred)
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
sns.barplot(
    x=y_range.numpy(),
    y=c.numpy(),
    color=sns_c[0],
    edgecolor=sns_c[0],
    alpha=0.7,
    label='Sample Data Distribution',
    ax=ax2
)
sns.barplot(
    x=y_range_prior.numpy(),
    y=c_prior.numpy(),
    color=sns_c[1],
    edgecolor=sns_c[1],
    label='Prior Predictive Sample Data Distribution',
    alpha=0.7,
    ax=ax1
)
ax1.set(title=f'Poisson Samples (Sample Data & Prior Predictive Samples)')
ax1.tick_params(axis='y', labelcolor=sns_c[1])
ax1.xaxis.set_major_formatter(FormatStrFormatter('%.0f'))
ax1.legend(loc='upper right')
ax2.grid(None)
ax2.legend(bbox_to_anchor=(0.84, 0.92))
ax2.tick_params(axis='y', labelcolor=sns_c[0])
Define Model
Next we are going to define our inference model. How do we do this with TFP?
TFP performs probabilistic inference by evaluating the model using an unnormalized joint log probability function. The arguments to this joint_log_prob are data and model state. The function returns the log of the joint probability that the parameterized model generated the observed data.1
# First we set the model specification.
def build_model(a=4.5, b=2):
    # Prior distribution.
    rate = tfd.Gamma(concentration=a, rate=b)
    # Likelihood: independent samples of a Poisson distribution.
    observations = lambda rate: tfd.Sample(
        distribution=tfd.Poisson(rate=rate),
        sample_shape=len(poisson_samples)
    )
    return tfd.JointDistributionNamed(dict(rate=rate, obs=observations))

# We set the joint log probability as the target function we want to maximize.
def target_log_prob_fn(rate):
    model = build_model()
    return model.log_prob(rate=rate, obs=poisson_samples)
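As a quick sanity check (not part of the original code), we can draw a single sample from the joint model to inspect its structure; the dictionary keys match the names used in build_model:

# Sample once from the joint model: we get a dict with a scalar rate
# and a vector of simulated observations.
model = build_model()
model_sample = model.sample()
print(model_sample['rate'].shape)  # ()
print(model_sample['obs'].shape)   # (100,)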
Bayesian Inference
Grid-Like-Approximation of the Mean
Let us start with a simple approach (compare with Statistical Rethinking, Chapter 3), in which we run a grid search to estimate the value of lambda (rate) that maximizes the joint log probability.
# Define rates range.
rates = np.linspace(start=0.01, stop=10.0, num=1000)
# Compute joint-log-probability.
model_log_probs = np.array([
    target_log_prob_fn(rate).numpy()
    for rate in rates
])
# Get rate which maximizes the log-probability of the model.
log_prob_maximizer = rates[np.argmax(model_log_probs)]
# Plot the results.
fig, ax = plt.subplots()
sns.lineplot(x=rates, y=model_log_probs, color=sns_c[0], label='model_log_prob', ax=ax)
ax.axvline(x=rate_true, linestyle='--', color=sns_c[3], label=f'rate_true = {rate_true: 0.2f}')
ax.axvline(x=log_prob_maximizer , linestyle='--', color=sns_c[1], label=f'log-prob-maximizer: {log_prob_maximizer: 0.2f}')
ax.legend(loc='upper right')
ax.set(title='Model Log Probability', xlabel='rate', ylabel='log probability');
We see that the value that maximizes the joint log probability is ~1.95. Still, we are interested in the distribution of the parameter lambda (rate), not just a single point estimate.
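As a small extension of this grid idea (a sketch reusing rates and model_log_probs from above, not part of the original post), we can normalize the grid values to obtain an approximate posterior over the grid instead of only the maximizer:

# Normalize the joint log probabilities over the grid to get an
# approximate (discrete) posterior distribution for the rate.
grid_log_probs = model_log_probs - model_log_probs.max()  # stabilize before exponentiating
grid_probs = np.exp(grid_log_probs)
grid_probs /= grid_probs.sum()
# Approximate posterior mean of the rate over the grid.
grid_post_mean = np.sum(rates * grid_probs)
print(f'approximate grid posterior mean: {grid_post_mean: 0.3f}')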
Hamiltonian Monte Carlo Sampling
Next, we are going to use Hamiltonian Monte Carlo sampling, which is a very common way to run Bayesian inference. We are not going to go into the details of the method, but rather focus on its direct usage with TFP. From TFP’s documentation:2
Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) algorithm that takes a series of gradient-informed steps to produce a Metropolis proposal.
Here are some useful references on this topic:
- Statistical Rethinking, Chapter 9.
- The Geometry of Hamiltonian Monte Carlo
- A Conceptual Introduction to Hamiltonian Monte Carlo
For an in-depth description of the objects and methods please refer to the documentation.
Let us set up the Hamiltonian Monte Carlo algorithm.
# Size of each chain.
num_results = int(1e4)
# Burn-in steps.
num_burnin_steps = int(1e3)
# Hamiltonian Monte Carlo transition kernel.
# In TFP a TransitionKernel returns a new state given some old state.
hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn,
    step_size=1.0,
    num_leapfrog_steps=3
)
# This adapts the inner kernel's step_size.
adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=hmc_kernel,
    num_adaptation_steps=int(num_burnin_steps * 0.8)
)
@tf.function
def run_chain():
    # Run the chain (with burn-in).
    # Implements MCMC via repeated TransitionKernel steps.
    samples, is_accepted = tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=1.0,
        kernel=adaptive_hmc,
        trace_fn=lambda _, pkr: pkr.inner_results.is_accepted
    )
    return samples
Next, we run it and get samples from the posterior distribution for different chains.
# Set number of chains.
num_chains = 5
# Run sampling.
chains = [run_chain() for i in range(num_chains)]
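The trace_fn above already records whether each proposal was accepted, but run_chain discards that information. If we wanted to monitor the acceptance rate, one possible variation (a sketch, using a hypothetical run_chain_with_trace helper, not part of the original post) is:

@tf.function
def run_chain_with_trace():
    # Same as run_chain, but also return the traced is_accepted flags.
    samples, is_accepted = tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=1.0,
        kernel=adaptive_hmc,
        trace_fn=lambda _, pkr: pkr.inner_results.is_accepted
    )
    return samples, is_accepted

samples, is_accepted = run_chain_with_trace()
acceptance_rate = tf.reduce_mean(tf.cast(is_accepted, tf.float32)).numpy()
print(f'acceptance rate: {acceptance_rate: 0.2f}')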
Sampling Results
Let us visualize the samples and their distribution.
# We store the samples in a pandas dataframe.
chains_df = pd.DataFrame([t.numpy() for t in chains])
chains_df = chains_df.T.melt(var_name='chain_id', value_name='sample')
chains_df.head()
|   | chain_id | sample   |
|---|----------|----------|
| 0 | 0        | 1.867492 |
| 1 | 0        | 1.883758 |
| 2 | 0        | 1.917210 |
| 3 | 0        | 1.920594 |
| 4 | 0        | 1.920594 |
We plot the samples for each chain and indicate their mean and plus/minus 2 standard deviations from the mean.
fig, ax = plt.subplots(2, 1, figsize=(10, 8))
for i in range(5):
    chain_samples = chains_df \
        .query(f'chain_id == {i}') \
        .reset_index(drop=True) \
        ['sample']
    chain_samples_mean = chain_samples.mean()
    chain_samples_std = chain_samples.std()
    chain_samples_plus = chain_samples_mean + 2*chain_samples_std
    chain_samples_minus = chain_samples_mean - 2*chain_samples_std
    sns.distplot(a=chain_samples, color=sns_c[i], hist_kws={'alpha': 0.4}, label=f'chain_{i}', ax=ax[0])
    ax[0].axvline(x=chain_samples_plus, linestyle='--', color=sns_c[i], label=f'chain_{i}_plus = {chain_samples_plus: 0.2f}')
    ax[0].axvline(x=chain_samples_minus, linestyle='--', color=sns_c[i], label=f'chain_{i}_minus = {chain_samples_minus: 0.2f}')
    ax[1].plot(chain_samples, c=sns_c[i], alpha=0.4)
    ax[1].axhline(y=chain_samples_mean, linestyle='--', color=sns_c[i], label=f'chain_{i} mean = {chain_samples_mean: 0.2f}')
ax[0].legend(loc='center left', bbox_to_anchor=(1, 0.5))
ax[0].set(xlabel='rate', ylabel='')
ax[1].legend(loc='center left', bbox_to_anchor=(1, 0.5))
ax[1].set(xlabel='sample', ylabel='rate')
plt.suptitle('Hamiltonian Monte Carlo Chains', y=0.92);
The chains appear to have converged (further convergence diagnostics are out of the scope of this short example, but they are a key part of any real analysis).
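As a minimal illustration of such diagnostics (a sketch using the chains list from above, not in the original post), TFP ships with R-hat and effective sample size utilities:

# Stack the chains into shape [num_results, num_chains], the layout
# expected by the TFP diagnostics utilities.
chains_t = tf.stack(chains, axis=1)
# Gelman-Rubin potential scale reduction; values close to 1 suggest convergence.
r_hat = tfp.mcmc.potential_scale_reduction(chains_states=chains_t)
# Effective sample size per chain.
ess = tfp.mcmc.effective_sample_size(chains_t)
print(f'r_hat: {r_hat.numpy(): 0.4f}')
print(f'mean effective sample size: {tf.reduce_mean(ess).numpy(): 0.1f}')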
Predictions
Posterior Distribution
Similarly to the above, we now plot the distribution of the samples from all chains combined.
fig, ax = plt.subplots(2, 1)
chain_samples = chains_df['sample']
chain_samples_mean = chain_samples.mean()
chain_samples_std = chain_samples.std()
chain_samples_plus = chain_samples_mean + 2*chain_samples_std
chain_samples_minus = chain_samples_mean - 2*chain_samples_std
sns.distplot(a=chain_samples, color=sns_c[9], label=f'chains samples', ax=ax[0])
ax[0].axvline(x=chain_samples_plus, linestyle='--', color=sns_c[4], label=f'$\mu + 2\sigma$ = {chain_samples_plus: 0.2f}')
ax[0].axvline(x=chain_samples_minus, linestyle='--', color=sns_c[4], label=f'$\mu - 2\sigma$ = {chain_samples_minus: 0.2f}')
ax[1].plot(chain_samples, c=sns_c[9], alpha=0.7)
ax[1].axhline(y=chain_samples_mean, linestyle='--', color=sns_c[0], label=f'$\mu$ = {chain_samples_mean: 0.2f}')
ax[0].legend(loc='center left', bbox_to_anchor=(1, 0.5))
ax[1].legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.suptitle(f'Posterior Distribution (Rate)', y=0.92);
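Since the gamma prior is conjugate, we can compare these HMC estimates against the closed-form posterior Gamma(a + sum(y), b + n) mentioned above (a small sketch reusing a, b, n, poisson_samples and the combined chain statistics):

# Analytical posterior parameters from conjugacy: Gamma(a + sum(y), b + n).
a_post = a + poisson_samples.numpy().sum()
b_post = b + n
print(f'analytical posterior mean: {a_post / b_post: 0.3f}')
print(f'analytical posterior std: {np.sqrt(a_post) / b_post: 0.3f}')
print(f'hmc posterior mean: {chain_samples_mean: 0.3f}')
print(f'hmc posterior std: {chain_samples_std: 0.3f}')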
Posterior Predictive Sampling
Finally, let us sample from the posterior distribution of the rate parameter and then generate Poisson samples from them.
y_post_pred = tfd.Poisson(rate=chains_df['sample']).sample(1)
y_post_pred = tf.reshape(y_post_pred, [-1])
y_range_post, idy_post, c_post = tf.unique_with_counts(y_post_pred)
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
sns.barplot(
    x=y_range.numpy(),
    y=c.numpy(),
    color=sns_c[0],
    edgecolor=sns_c[0],
    alpha=0.7,
    label='Sample Data Distribution',
    ax=ax2
)
sns.barplot(
    x=y_range_post.numpy(),
    y=c_post.numpy(),
    color=sns_c[2],
    edgecolor=sns_c[2],
    label='Posterior Predictive Sample Data Distribution',
    alpha=0.7,
    ax=ax1
)
ax1.set(title='Poisson Samples (Sample Data & Posterior Predictive Samples)')
ax1.tick_params(axis='y', labelcolor=sns_c[2])
ax1.xaxis.set_major_formatter(FormatStrFormatter('%.0f'))
ax1.legend(loc='upper right')
ax2.grid(None)
ax2.legend(bbox_to_anchor=(0.8, 0.92))
ax2.tick_params(axis='y', labelcolor=sns_c[0])
The sample data distribution and the posterior predictive distribution look very similar.
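One simple way to make this comparison a bit more quantitative (a small sketch, not in the original post) is to look at basic summary statistics of both sets of samples:

# Compare mean and variance of the observed data and the posterior predictive samples.
print(f'observed mean: {tf.reduce_mean(poisson_samples).numpy(): 0.3f}')
print(f'observed variance: {tf.math.reduce_variance(poisson_samples).numpy(): 0.3f}')
print(f'posterior predictive mean: {tf.reduce_mean(y_post_pred).numpy(): 0.3f}')
print(f'posterior predictive variance: {tf.math.reduce_variance(y_post_pred).numpy(): 0.3f}')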