In this notebook we take a look at the distributions module of TensorFlow Probability. The aim is to understand the fundamentals and then explore this probabilistic programming framework further. Here you can find an overview of TensorFlow Probability. We will concentrate on the first part of Layer 1: Statistical Building Blocks. As you can see from the distributions module documentation, there are many classes of distributions. We will explore a small sample of them in order to get an overall overview. I find the documentation itself a great place to start. In addition, there is a collection of notebooks with concrete examples in the GitHub repository. In particular, I will follow some of the cases presented in the A_Tour_of_TensorFlow_Probability notebook, expand on some details, and add some additional examples.
Prepare Notebook
import numpy as np
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions
# Data Viz.
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style(
style='darkgrid',
rc={'axes.facecolor': '.9', 'grid.color': '.8'}
)
sns.set_palette(palette='deep')
sns_c = sns.color_palette(palette='deep')
%matplotlib inline
from pandas.plotting import register_matplotlib_converters
register_matplotlib_converters()
plt.rcParams['figure.figsize'] = [12, 6]
plt.rcParams['figure.dpi'] = 100
# Get TensorFlow version.
print(f'TensorFlow version: {tf.__version__}')
print(f'TensorFlow Probability version: {tfp.__version__}')
TensorFlow version: 2.2.0
TensorFlow Probability version: 0.10.0
Distributions
Let us consider a couple of well-known distributions to begin:
normal = tfd.Normal(loc=0.0, scale=1.0)
gamma = tfd.Gamma(concentration=5.0, rate=1.0)
poisson = tfd.Poisson(rate=2.0)
laplace = tfd.Laplace(loc=0.0, scale=1.0)
Let us sample from each of them and visualize their distributions.
n_samples = 800
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
axes = axes.flatten()
sns.distplot(a=normal.sample(n_samples), color=sns_c[0], rug=True, ax=axes[0])
axes[0].set(title=f'Normal Distribution')
sns.distplot(a=gamma.sample(n_samples), color=sns_c[1], rug=True, ax=axes[1])
axes[1].set(title=f'Gamma Distribution');
sns.distplot(a=poisson.sample(n_samples), color=sns_c[2], kde=False, rug=True, ax=axes[2])
axes[2].set(title='Poisson Distribution');
sns.distplot(a=laplace.sample(n_samples), color=sns_c[3], rug=True, ax=axes[3])
axes[3].set(title='Laplace Distribution')
plt.suptitle(f'Distribution Samples ({n_samples})', y=0.95);
Distributions have shapes, much like tensors, and these can have many dimensions. In particular:
- Batch shape denotes a collection of Distributions with distinct parameters.
- Event shape denotes the shape of samples from the Distribution.
As a convention, batch shapes are on the “left” and event shapes on the “right”. The sketch below makes the distinction concrete.
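To illustrate, here is a minimal sketch comparing a batch of two univariate normals with a single bivariate normal (the latter via MultivariateNormalDiag, which we will not need again below):
# Batch of two univariate normals: batch_shape=[2], event_shape=[].
batch_normal = tfd.Normal(loc=[0.0, 1.0], scale=1.0)
print(batch_normal.batch_shape, batch_normal.event_shape)
# A single bivariate normal: batch_shape=[], event_shape=[2].
mvn_diag = tfd.MultivariateNormalDiag(loc=[0.0, 1.0], scale_diag=[1.0, 1.0])
print(mvn_diag.batch_shape, mvn_diag.event_shape)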
We can draw samples whose shape is specified by a list of sizes:
normal_samples = normal.sample([n_samples, n_samples])
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes = axes.flatten()
for i in range(2):
    sns.distplot(a=normal_samples[i], color=sns_c[0], rug=True, ax=axes[i])
    axes[i].set(title=f'Normal Distribution (Iter {i})')
plt.suptitle(f'Distribution Samples ({n_samples})', y=0.97);
We can also compute common statistics of the distribution, such as its cumulative distribution function and quantiles:
x = tf.linspace(start=-5.0, stop=5.0, num=100)
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
sns.lineplot(x=x, y=normal.cdf(x), color=sns_c[3], label='cdf', ax=ax1)
sns.distplot(a=normal.sample(n_samples), color=sns_c[0], label='samples', rug=True, ax=ax2)
ax1.axvline(x=0.0, color=sns_c[2], linestyle='--', label=r'$\mu=0$')
q_list = [0.05, 0.95]
quantiles = normal.quantile(q_list).numpy()
for i, q in zip(q_list, quantiles):
    ax1.axvline(x=q, color=sns_c[1], linestyle='--', label=f'quantile {i}')
ax1.tick_params(axis='y', labelcolor=sns_c[0])
ax2.grid(None)
ax2.tick_params(axis='y', labelcolor=sns_c[3])
ax1.legend(loc='upper left')
ax2.legend(loc='center left')
ax1.set(title='Normal Distribution');
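Besides cdf and quantile, the distribution object exposes other handy methods, such as mean, stddev, prob and log_prob. A quick sketch:
# Analytic moments and (log-)density evaluations of the standard normal.
print(normal.mean().numpy())
print(normal.stddev().numpy())
print(normal.prob(0.0).numpy())  # density at the mean, approx. 0.3989
print(normal.log_prob(0.0).numpy())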
Next, let us consider a sequence of normal distributions with fixed standard deviation and increasing mean.
loc_list = np.linspace(start=0.0, stop=8.0, num=5)
normals = tfd.Normal(loc=loc_list, scale=1.0)
normal_samples = normals.sample(n_samples)
normal_samples.shape
TensorShape([800, 5])
fig, ax = plt.subplots()
for i in range(normals.batch_shape[0]):
    sns.distplot(a=normal_samples[:, i], ax=ax)
ax.set(title='Batch Samples Normal Distribution');
- Entropy
Let us compute the (Shannon) entropy of each distribution.
fig, ax = plt.subplots(figsize=(8, 6))
sns.barplot(x=list(range(normals.batch_shape[0])), y=normals.entropy(), ax=ax)
ax.set(title='Entropy', xlabel='batch');
As expected, these values do not depend on the mean. Indeed, the entropy of a normal distribution depends only on the standard deviation σ; it equals (1/2)·log(2πeσ²). The explicit value for σ = 1 can be computed as:
(1/2)*np.log(2*np.pi*np.exp(1)*1.0)
1.4189385332046727
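As a quick check, this matches the value TFP reports for the standard normal defined above:
# Entropy of the standard normal, as computed by TFP.
print(normal.entropy().numpy())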
- Cross Entropy
We now compute the cross-entropy from the first normal distribution to the rest. We display the sample distributions as violin plots.
fig, ax1 = plt.subplots()
sns.violinplot(data=normal_samples, ax=ax1)
ax2 = ax1.twinx()
sns.lineplot(
x=range(normals.batch_shape[0]),
y=normals[0].cross_entropy(normals),
color='black',
alpha=0.5,
ax=ax2
)
sns.scatterplot(
x=range(normals.batch_shape[0]),
y=normals[0].cross_entropy(normals),
s=80,
color='black',
label='cross entropy (first batch to others)',
ax=ax2
)
ax2.grid(None)
ax2.legend(loc='lower right')
ax2.tick_params(axis='y', labelcolor='black')
ax2.set(ylabel='cross entropy')
ax1.set(title='Cross-Entropy from First Batch to the Others', xlabel='batch');
As expected, the cross-entropy increases as the difference between the means grows. For an introduction and an enlightening discussion of the concept of entropy in probability theory, I recommend the (fantastic!) book Statistical Rethinking by Richard McElreath.
- KL Divergence
We generate a similar plot for the Kullback–Leibler divergence.
fig, ax1 = plt.subplots()
sns.violinplot(data=normal_samples, ax=ax1)
ax2 = ax1.twinx()
sns.lineplot(x=range(normals.batch_shape[0]), y=normals[0].kl_divergence(normals), color='black', alpha=0.5, ax=ax2)
sns.scatterplot(x=range(normals.batch_shape[0]), y=normals[0].kl_divergence(normals), s=80, color='black', label='Kullback-Leibler divergence (first batch to others)', ax=ax2)
ax2.grid(None)
ax2.legend(loc='lower right')
ax2.tick_params(axis='y', labelcolor='black')
ax2.set(ylabel='Kullback-Leibler Divergence')
ax1.set(title='Kullback-Leibler Divergence from First Batch to the Others', xlabel='batch');
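As a side note, cross-entropy and KL divergence are tightly related: the cross-entropy equals the entropy plus the KL divergence. A minimal numerical check with the batch defined above:
# Check that cross_entropy(p, q) = entropy(p) + kl_divergence(p, q).
ce = normals[0].cross_entropy(normals)
h_plus_kl = normals[0].entropy() + normals[0].kl_divergence(normals)
print(np.allclose(ce.numpy(), h_plus_kl.numpy()))  # expected: True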
Multivariate Normal Distribution
We can easily sample from a Multivariate Normal distribution (see here for more details). Here, the MultivariateNormalTriL class receives as input the mean vector and the Cholesky decomposition of the covariance matrix.
mu = [1.0, 2.0]
cov = [[2.0, 1.0],
[1.0, 1.0]]
cholesky = tf.linalg.cholesky(cov)
multi_normal = tfd.MultivariateNormalTriL(loc=mu, scale_tril=cholesky)
multi_normal_samples = multi_normal.sample(n_samples)
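Before plotting, a quick sanity check (the exact numbers will of course vary from run to run): the sample mean and sample covariance should be close to mu and cov.
# Compare the empirical moments of the samples with the specified parameters.
print(np.mean(multi_normal_samples.numpy(), axis=0))       # approx. [1.0, 2.0]
print(np.cov(multi_normal_samples.numpy(), rowvar=False))  # approx. [[2.0, 1.0], [1.0, 1.0]]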
Let us plot the samples:
g = sns.jointplot(
x=multi_normal_samples[:,0],
y=multi_normal_samples[:, 1],
space=0,
height=5,
)
g.fig.suptitle('Multinormal Samples', y=1.04);
Now let us plot the (marginal) KDE.
g = sns.JointGrid(
x=multi_normal_samples[:,0],
y=multi_normal_samples[:, 1],
space=0,
height=5
)
g = g.plot_joint(sns.kdeplot, cmap='Blues_d')
g = g.plot_marginals(sns.kdeplot, shade=True)
g.fig.suptitle('Multinormal Samples (KDE Plot)', y=1.04);
Gaussian Process
A very useful distribution is the Gaussian Process. For more details on this type of distribution you can see An Introduction to Gaussian Process Regression.
- ExponentiatedQuadratic Kernel
# Let us consider a Gaussian Process with an ExponentiatedQuadratic kernel.
eq_kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(amplitude=0.1, length_scale=0.5)
# Define grid (X).
xs = tf.reshape(tf.linspace(0.0, 10.0, 100), [-1, 1])
# Define Gaussian Process object.
gp_eq = tfd.GaussianProcess(kernel=eq_kernel, index_points=xs)
# Sample from the Gaussian Process.
gp_eq_samples = gp_eq.sample(7)
Now we plot each sample. Recall that this will give us functions defined on the grid xs.
fig, ax = plt.subplots()
for i in range(7):
    sns.lineplot(
        x=xs[..., 0],
        y=gp_eq_samples[i, :],
        color=sns_c[i],
        ax=ax
    )
sns.lineplot(x=xs[..., 0], y=gp_eq.mean(), color='black', ax=ax)
# Compatibility interval.
upper, lower = gp_eq.mean() + [2 * gp_eq.stddev(), -2 * gp_eq.stddev()]
ax.fill_between(xs[..., 0], upper, lower, color=sns_c[8], alpha=0.2, label=r'gp_mean $\pm$ 2std')
ax.legend()
ax.set(title='Gaussian Process Samples (ExponentiatedQuadratic Kernel)');
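Under the hood, the Gaussian process builds its covariance matrix by evaluating the kernel on the index points. We can inspect that matrix directly (a short sketch):
# Kernel (covariance) matrix evaluated on the grid; shape [100, 100].
k_matrix = eq_kernel.matrix(xs, xs)
print(k_matrix.shape)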
- ExpSinSquared Kernel
# ExpSinSquared (periodic) Kernel.
es_kernel = tfp.math.psd_kernels.ExpSinSquared(amplitude=0.5, length_scale=1.5, period=3)
# Define grid (X).
xs = tf.reshape(tf.linspace(0.0, 10.0, 80), [-1, 1])
# Define Gaussian Process object.
gp_es = tfd.GaussianProcess(kernel=es_kernel, index_points=xs)
# Sample from Gaussian Process.
gp_es_samples = gp_es.sample(7)
We plot the samples (note that each sample is indeed a periodic function).
fig, ax = plt.subplots()
for i in range(7):
    sns.lineplot(
        x=xs[..., 0],
        y=gp_es_samples[i, :],
        color=sns_c[i],
        ax=ax
    )
sns.lineplot(x=xs[..., 0], y=gp_es.mean(), color='black', ax=ax)
# Compatibility interval.
upper, lower = gp_es.mean() + [2 * gp_es.stddev(), -2 * gp_es.stddev()]
ax.fill_between(xs[..., 0], upper, lower, color=sns_c[8], alpha=0.2, label=r'gp_mean $\pm$ 2std')
ax.legend()
ax.set(title='Gaussian Process Samples (ExpSinSquared Kernel)');
- Gaussian Process Regression
Let us illustrate, with a simple toy model, how to solve an interpolation problem using Gaussian Process Regression.
# Define true generating function on the grid.
ys = tf.math.sin(1*np.pi*xs) + tf.math.sin(0.5*np.pi*xs)
# Get 20 random points.
n_obs = 20
indices = np.random.uniform(low=0.0, high=xs.shape[0] - 1, size=n_obs )
indices = [int(x) for x in np.round(indices)]
x_obs = tf.gather(params=xs, indices=indices)
y_obs = tf.gather(params=ys, indices=indices)
# Add Gaussian noise.
y_obs = y_obs[..., 0] + tf.random.normal(mean=0.0, stddev=0.05, shape=(n_obs, ))
# Plot.
fig, ax = plt.subplots()
sns.lineplot(x=xs[..., 0], y=ys[..., 0], color=sns_c[0], alpha=0.3, ax=ax)
sns.scatterplot(x=x_obs[..., 0], y=y_obs, color=sns_c[0], s=20, edgecolor=sns_c[0], ax=ax)
ax.set(title='Sample Data');
Assuming we know the data is periodic (up to Gaussian noise), we can use an ExpSinSquared kernel to model the data.
kernel = tfp.math.psd_kernels.ExpSinSquared(amplitude=0.5, length_scale=1.0, period=4)
gprm = tfd.GaussianProcessRegressionModel(
kernel=kernel,
index_points=xs,
observation_index_points=x_obs,
observations=y_obs,
observation_noise_variance=0.05**2,
)
fig, ax = plt.subplots()
# Sample many times to get fits.
for _ in range(25):
    sns.lineplot(x=xs[..., 0], y=gprm.sample(), c=sns_c[3], alpha=0.1, ax=ax)
# Compatibility interval.
upper, lower = gprm.mean() + [2 * gprm.stddev(), -2 * gprm.stddev()]
sns.lineplot(x=xs[..., 0], y=ys[..., 0], color=sns_c[0], alpha=0.3, ax=ax)
sns.scatterplot(x=x_obs[..., 0], y=y_obs, color=sns_c[0], s=40, edgecolor=sns_c[0], ax=ax)
ax.fill_between(xs[..., 0], upper, lower, color=sns_c[8], alpha=0.3, label=r'gp_mean $\pm$ 2std')
ax.set(title='Gaussian Process Regression Fit - ExpSinSquared Kernel');
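As a rough (and admittedly ad hoc) measure of fit, we can compare the posterior mean against the noise-free generating function on the grid:
# Mean squared error between the posterior mean and the true function.
mse = tf.reduce_mean(tf.square(gprm.mean() - ys[..., 0]))
print(mse.numpy())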