
A Glimpse into TensorFlow Probability Distributions

In this notebook we take a look at the distributions module of TensorFlow Probability. The aim is to understand the fundamentals and then explore this probabilistic programming framework further. Here you can find an overview of TensorFlow Probability. We will concentrate on the first part of Layer 1: Statistical Building Blocks. As you can see from the distributions module documentation, there are many classes of distributions. We will explore a small sample of them to get a general overview. I find the documentation itself a great place to start. In addition, there is a collection of notebooks with concrete examples in the GitHub repository. In particular, I will follow some of the cases presented in the A_Tour_of_TensorFlow_Probability notebook, expand on some details, and add a few additional examples.

Prepare Notebook

import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# Data Viz. 
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style(rc={'axes.facecolor': '.9', 'grid.color': '.8'})
sns_c = sns.color_palette(palette='deep')
%matplotlib inline
from pandas.plotting import register_matplotlib_converters

plt.rcParams['figure.figsize'] = [12, 6]
plt.rcParams['figure.dpi'] = 100
# Get TensorFlow version.
print(f'TensorFlow version: {tf.__version__}')
print(f'TensorFlow Probability version: {tfp.__version__}')
TensorFlow version: 2.2.0
TensorFlow Probability version: 0.10.0


Let us consider a couple of well-known distributions to begin:

normal = tfd.Normal(loc=0.0, scale=1.0)
gamma = tfd.Gamma(concentration=5.0, rate=1.0)
poisson = tfd.Poisson(rate=2.0)
laplace = tfd.Laplace(loc=0.0, scale=1.0)

Let us sample from each of them and visualize their distributions.

n_samples = 800

fig, axes = plt.subplots(2, 2, figsize=(10, 10))

axes = axes.flatten()

sns.distplot(a=normal.sample(n_samples), color=sns_c[0], rug=True, ax=axes[0])
axes[0].set(title=f'Normal Distribution')

sns.distplot(a=gamma.sample(n_samples), color=sns_c[1], rug=True, ax=axes[1])
axes[1].set(title=f'Gamma Distribution');

sns.distplot(a=poisson.sample(n_samples), color=sns_c[2], kde=False, rug=True, ax=axes[2])
axes[2].set(title='Poisson Distribution');

sns.distplot(a=laplace.sample(n_samples), color=sns_c[3], rug=True, ax=axes[3])
axes[3].set(title='Laplace Distribution')

plt.suptitle(f'Distribution Samples ({n_samples})', y=0.95);

We treat distributions as tensors, which can have many dimensions. In particular,

  • Batch shape denotes a collection of Distributions with distinct parameters.

  • Event shape denotes the shape of samples from the Distribution.

As a convention, batch shapes are on the “left” and event shapes on the “right”.
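
To make the shape convention concrete, here is a small bookkeeping sketch in plain Python (it makes no TFP calls). It assumes the usual rule that sampling returns a tensor of shape sample_shape + batch_shape + event_shape. The helper name sampled_shape is hypothetical and exists only for this illustration.

```python
def sampled_shape(sample_shape, batch_shape, event_shape):
    """Shape of a sampled tensor: sample dims first, then batch dims, then event dims."""
    return tuple(sample_shape) + tuple(batch_shape) + tuple(event_shape)


# A single scalar normal: no batch dimensions, scalar events.
scalar_normal = sampled_shape(sample_shape=[800], batch_shape=[], event_shape=[])

# Five scalar normals with different means: batch shape [5], scalar events.
batch_of_normals = sampled_shape(sample_shape=[800], batch_shape=[5], event_shape=[])

# One bivariate normal: no batch dimensions, each event is a vector of length 2.
bivariate_normal = sampled_shape(sample_shape=[800], batch_shape=[], event_shape=[2])

print(scalar_normal, batch_of_normals, bivariate_normal)
```

Reading a shape from left to right therefore tells you how many draws were taken, how many distributions were sampled in parallel, and what a single draw looks like.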

We can draw a whole tensor of samples by passing a list of sizes as the sample shape:

normal_samples = normal.sample([n_samples, n_samples])

fig, axes = plt.subplots(1, 2, figsize=(10, 5))

axes = axes.flatten()

for i in range(2):
    sns.distplot(a=normal_samples[i], color=sns_c[0], rug=True, ax=axes[i])
    axes[i].set(title=f'Normal Distribution (Iter {i})')
plt.suptitle(f'Distribution Samples ({n_samples})', y=0.97);

We can also evaluate common functions of the distribution itself, such as the CDF and quantiles, and compare them with the samples:

x = tf.linspace(start=-5.0, stop=5.0, num=100)

fig, ax1 = plt.subplots()
ax2 = ax1.twinx() 
sns.lineplot(x=x, y=normal.cdf(x), color=sns_c[3], label='cdf', ax=ax1)
sns.distplot(a=normal.sample(n_samples), color=sns_c[0], label='samples', rug=True, ax=ax2)
ax1.axvline(x=0.0, color=sns_c[2], linestyle='--', label=r'$\mu=0$')

q_list = [0.05, 0.95]
quantiles = normal.quantile(q_list).numpy()
for i, q in zip(q_list , quantiles):
    ax1.axvline(x=q, color=sns_c[1], linestyle='--', label=f'quantile {i}')

# Match each axis' tick colors to the series drawn on it (cdf on ax1, samples on ax2).
ax1.tick_params(axis='y', labelcolor=sns_c[3])
ax2.tick_params(axis='y', labelcolor=sns_c[0])
ax1.legend(loc='upper left')
ax2.legend(loc='center left')
ax1.set(title='Normal Distribution');
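
As a cross-check that does not need TensorFlow, recall that the quantile function is the inverse of the CDF. The sketch below uses NormalDist from the Python standard library (Python 3.8 or later), not TFP, to recover the two quantile levels drawn in the plot.

```python
from statistics import NormalDist

standard_normal = NormalDist(mu=0.0, sigma=1.0)

q_list = [0.05, 0.95]
# inv_cdf plays the role of the quantile method used above.
quantiles = [standard_normal.inv_cdf(q) for q in q_list]
# Applying the CDF to a quantile returns the original probability level.
round_trip = [standard_normal.cdf(x) for x in quantiles]

for q, x, p in zip(q_list, quantiles, round_trip):
    print(f'q={q:.2f}  quantile={x:+.4f}  cdf(quantile)={p:.4f}')
```

The two quantiles are symmetric around the mean, which is why the dashed lines in the plot sit at equal distances from zero.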

Next, let us consider a sequence of normal distributions with fixed standard deviation and increasing mean.

loc_list = np.linspace(start=0.0, stop=8.0, num=5)
normals = tfd.Normal(loc=loc_list, scale=1.0)
normal_samples = normals.sample(n_samples)
normal_samples.shape
TensorShape([800, 5])
fig, ax = plt.subplots()
for i in range(normals.batch_shape[0]):
    sns.distplot(a=normal_samples[:, i], ax=ax)
ax.set(title='Batch Samples Normal Distribution');

  • Entropy

Let us compute the (Shannon) entropy of each distribution.

fig, ax = plt.subplots(figsize=(8, 6))
sns.barplot(x=list(range(normals.batch_shape[0])), y=normals.entropy(), ax=ax)
ax.set(title='Entropy', xlabel='batch');

As expected, these values do not depend on the mean. Indeed, the entropy of a normal distribution depends only on the standard deviation $\sigma$. The explicit value can be computed as $H = \frac{1}{2} \ln(2 \pi e \sigma^2)$, which is approximately 1.419 for $\sigma = 1$.
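
A minimal sketch in plain Python (no TFP) that evaluates the closed-form entropy of a normal distribution and compares it with a direct numerical integration of the definition for several means. The function names and the integration grid are choices made for this illustration.

```python
import math


def normal_entropy(sigma):
    """Closed-form differential entropy (in nats) of a normal distribution."""
    return 0.5 * math.log(2.0 * math.pi * math.e * sigma ** 2)


def numerical_entropy(mu, sigma, n_sigmas=12.0, n_steps=24_000):
    """Approximate -integral p(x) ln p(x) dx with a midpoint Riemann sum."""
    step = 2.0 * n_sigmas * sigma / n_steps
    total = 0.0
    for k in range(n_steps):
        x = mu - n_sigmas * sigma + (k + 0.5) * step
        log_p = -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))
        total -= math.exp(log_p) * log_p * step
    return total


closed_form = normal_entropy(sigma=1.0)
# The same value comes out regardless of the mean.
numerical = [numerical_entropy(mu=mu, sigma=1.0) for mu in (0.0, 2.0, 8.0)]
print(closed_form, numerical)
```

Shifting the mean only translates the density, so the integral is unchanged, which matches the flat bar plot.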

  • Cross Entropy

We now compute the cross-entropy from the first normal distribution to the rest. We display the sample distributions as violin plots.

fig, ax1 = plt.subplots()
sns.violinplot(data=normal_samples, ax=ax1)
ax2 = ax1.twinx()
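sns.lineplot(x=range(normals.batch_shape[0]), y=normals[0].cross_entropy(normals), color='black', alpha=0.5, ax=ax2)
sns.scatterplot(x=range(normals.batch_shape[0]), y=normals[0].cross_entropy(normals), s=80, color='black', label='cross entropy (first batch to others)', ax=ax2)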
ax2.legend(loc='lower right')
ax2.tick_params(axis='y', labelcolor='black')
ax2.set(ylabel='cross entropy')
ax1.set(title='Cross-Entropy from First Batch to the Others', xlabel='batch');

As expected, the cross-entropy increases with the difference between the means. For an introduction to, and an enlightening discussion of, the concept of entropy in probability theory, I recommend the (fantastic!) book Statistical Rethinking by Richard McElreath.
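
For two normal distributions the cross-entropy has a closed form, so the growth can be checked by hand. The sketch below is plain Python rather than TFP, and it hard-codes the five means and the unit standard deviation used in this section.

```python
import math


def normal_cross_entropy(mu_p, sigma_p, mu_q, sigma_q):
    """Closed-form cross-entropy H(p, q) = -E_p[ln q(x)] for two normal distributions."""
    return (
        0.5 * math.log(2.0 * math.pi * sigma_q ** 2)
        + (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2.0 * sigma_q ** 2)
    )


# Same means as the batch of normals: five equally spaced values from 0 to 8.
means = [0.0, 2.0, 4.0, 6.0, 8.0]
cross_entropies = [normal_cross_entropy(means[0], 1.0, mu, 1.0) for mu in means]
print([round(h, 4) for h in cross_entropies])
```

The first entry is the entropy of the first distribution itself, and each subsequent entry adds half the squared distance between the means.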

  • KL Divergence

We generate a similar plot for the Kullback–Leibler divergence.

fig, ax1 = plt.subplots()
sns.violinplot(data=normal_samples, ax=ax1)
ax2 = ax1.twinx()
sns.lineplot(x=range(normals.batch_shape[0]), y=normals[0].kl_divergence(normals), color='black',  alpha=0.5, ax=ax2)
sns.scatterplot(x=range(normals.batch_shape[0]), y=normals[0].kl_divergence(normals), s=80, color='black', label='Kullback-Leibler divergence (first batch to others)', ax=ax2)
ax2.legend(loc='lower right')
ax2.tick_params(axis='y', labelcolor='black')
ax2.set(ylabel='Kullback-Leibler Divergence')
ax1.set(title='Kullback-Leibler Divergence from First Batch to the Others', xlabel='batch');
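
The Kullback–Leibler divergence between two univariate normals also has a closed form. The following plain-Python sketch (not TFP code) evaluates it for the five means used in this section, and adds one pair with unequal standard deviations, chosen only for illustration, to show that the divergence is not symmetric in general.

```python
import math


def normal_kl(mu_p, sigma_p, mu_q, sigma_q):
    """Closed-form KL(p || q) for two univariate normal distributions."""
    return (
        math.log(sigma_q / sigma_p)
        + (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2.0 * sigma_q ** 2)
        - 0.5
    )


# Divergence from the first distribution of the batch to each of the five.
means = [0.0, 2.0, 4.0, 6.0, 8.0]
kl_values = [normal_kl(means[0], 1.0, mu, 1.0) for mu in means]
print(kl_values)

# With unequal standard deviations the two directions give different values.
forward = normal_kl(0.0, 1.0, 0.0, 2.0)
backward = normal_kl(0.0, 2.0, 0.0, 1.0)
print(forward, backward)
```

With equal standard deviations the divergence reduces to half the squared distance between the means, so it is zero for the first distribution compared with itself.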

Multivariate Normal Distribution

We can easily sample from a multivariate normal distribution (see here for more details). Here, the MultivariateNormalTriL class takes as input the mean vector and the lower-triangular Cholesky factor of the covariance matrix.

mu = [1.0, 2.0]
cov = [[2.0, 1.0],
       [1.0, 1.0]]
cholesky = tf.linalg.cholesky(cov)
multi_normal = tfd.MultivariateNormalTriL(loc=mu, scale_tril=cholesky)
multi_normal_samples = multi_normal.sample(n_samples)
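
To see why the Cholesky factor is a convenient parametrization, here is a sketch of the standard construction x = mu + L z, with z a vector of independent standard normals. It is written in plain Python for the 2x2 case only and is not meant to describe TFP's internals; the helper names and the seed are hypothetical.

```python
import math
import random


def cholesky_2x2(cov):
    """Lower-triangular L with L @ L.T == cov, for a 2x2 symmetric positive-definite matrix."""
    l00 = math.sqrt(cov[0][0])
    l10 = cov[1][0] / l00
    l11 = math.sqrt(cov[1][1] - l10 ** 2)
    return [[l00, 0.0], [l10, l11]]


def sample_bivariate_normal(mu, tril, rng):
    """Draw one sample as mu + L z, where z holds two independent standard normals."""
    z0, z1 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    return [mu[0] + tril[0][0] * z0, mu[1] + tril[1][0] * z0 + tril[1][1] * z1]


mu = [1.0, 2.0]
cov = [[2.0, 1.0], [1.0, 1.0]]
tril = cholesky_2x2(cov)

rng = random.Random(42)  # fixed seed, chosen arbitrarily for reproducibility
samples = [sample_bivariate_normal(mu, tril, rng) for _ in range(20_000)]

# Empirical covariance between the two coordinates; it should be close to cov[0][1] = 1.
mean_x = sum(s[0] for s in samples) / len(samples)
mean_y = sum(s[1] for s in samples) / len(samples)
cov_xy = sum((s[0] - mean_x) * (s[1] - mean_y) for s in samples) / len(samples)
print(tril, cov_xy)
```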

Let us plot the samples:

g = sns.jointplot(
    x=multi_normal_samples[:, 0],
    y=multi_normal_samples[:, 1],
)
g.fig.suptitle('Multinormal Samples', y=1.04);

Now let us plot the joint and marginal KDEs.

g = sns.JointGrid(
    x=multi_normal_samples[:, 0],
    y=multi_normal_samples[:, 1],
)
g = g.plot_joint(sns.kdeplot, cmap='Blues_d')
g = g.plot_marginals(sns.kdeplot, shade=True)
g.fig.suptitle('Multinormal Samples (KDE Plot)', y=1.04);

Gaussian Process

A very useful distribution is the Gaussian process. For more details on this kind of distribution, see An Introduction to Gaussian Process Regression.

  • ExponentiatedQuadratic Kernel

# Let us consider a Gaussian Process with an exponentiated quadratic kernel.
eq_kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(amplitude=0.1, length_scale=0.5)
# Define grid (X).
xs = tf.reshape(tf.linspace(0.0, 10.0, 100), [-1, 1])
# Define Gaussian Process object.
gp_eq = tfd.GaussianProcess(kernel=eq_kernel, index_points=xs)
# Sample from the Gaussian Process. 
gp_eq_samples = gp_eq.sample(7)
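
For intuition about the two kernel parameters, the exponentiated quadratic kernel can be written out by hand. The sketch below is plain Python for scalar inputs, using the same amplitude and length scale as above; it is an illustration of the formula, not a replacement for the TFP class.

```python
import math


def exponentiated_quadratic(x, y, amplitude=0.1, length_scale=0.5):
    """k(x, y) = amplitude^2 * exp(-(x - y)^2 / (2 * length_scale^2)) for scalar inputs."""
    return amplitude ** 2 * math.exp(-((x - y) ** 2) / (2.0 * length_scale ** 2))


# The prior variance at any point is amplitude^2, so the prior stddev is the amplitude.
variance_at_point = exponentiated_quadratic(1.0, 1.0)

# The correlation between two function values decays with their distance.
for distance in (0.0, 0.5, 1.0, 2.0):
    correlation = exponentiated_quadratic(0.0, distance) / variance_at_point
    print(f'distance={distance:.1f}  correlation={correlation:.4f}')
```

The amplitude sets the vertical scale of the sampled functions, while the length scale controls how quickly they can wiggle.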

Now we plot each sample. Recall that this will give us functions defined on the grid xs.

fig, ax = plt.subplots()
for i in range(7):
    sns.lineplot(
        x=xs[..., 0],
        y=gp_eq_samples[i, :],
        ax=ax
    )
sns.lineplot(x=xs[..., 0], y=gp_eq.mean(), color='black', ax=ax)
# Compatibility interval.
upper, lower = gp_eq.mean() + [2 * gp_eq.stddev(), -2 * gp_eq.stddev()]
ax.fill_between(xs[..., 0], upper, lower, color=sns_c[8], alpha=0.2, label=r'gp_mean $\pm$ 2std')
ax.set(title='Gaussian Process Samples (ExponentiatedQuadratic Kernel)');

  • ExpSinSquared Kernel

# ExpSinSquared (periodic) Kernel.
es_kernel = tfp.math.psd_kernels.ExpSinSquared(amplitude=0.5, length_scale=1.5, period=3)
# Define grid (X).
xs = tf.reshape(tf.linspace(0.0, 10.0, 80), [-1, 1])
# Define Gaussian Process object.
gp_es = tfd.GaussianProcess(kernel=es_kernel, index_points=xs)
# Sample from Gaussian Process.
gp_es_samples = gp_es.sample(7)
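
The periodicity comes directly from the kernel formula. As a plain-Python sketch for scalar inputs (same parameter values as above, not the TFP implementation), two points that are exactly one period apart are as strongly correlated as a point with itself:

```python
import math


def exp_sin_squared(x, y, amplitude=0.5, length_scale=1.5, period=3.0):
    """k(x, y) = amplitude^2 * exp(-2 * sin^2(pi * |x - y| / period) / length_scale^2)."""
    sine = math.sin(math.pi * abs(x - y) / period)
    return amplitude ** 2 * math.exp(-2.0 * sine ** 2 / length_scale ** 2)


same_point = exp_sin_squared(0.7, 0.7)
one_period_apart = exp_sin_squared(0.7, 0.7 + 3.0)
half_period_apart = exp_sin_squared(0.7, 0.7 + 1.5)
print(same_point, one_period_apart, half_period_apart)
```

The covariance is smallest at half a period, where the sine term reaches its maximum.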

We plot the samples (note that each sample is indeed a periodic function).

fig, ax = plt.subplots()
for i in range(7):
    sns.lineplot(
        x=xs[..., 0],
        y=gp_es_samples[i, :],
        ax=ax
    )
sns.lineplot(x=xs[..., 0], y=gp_es.mean(), color='black', ax=ax)
# Compatibility interval.
upper, lower = gp_es.mean() + [2 * gp_es.stddev(), -2 * gp_es.stddev()]
ax.fill_between(xs[..., 0], upper, lower, color=sns_c[8], alpha=0.2, label=r'gp_mean $\pm$ 2std')
ax.set(title='Gaussian Process Samples (ExpSinSquared Kernel)');

  • Gaussian Process Regression

Let us illustrate with a simple toy model how to solve an interpolation problem using Gaussian process regression.

# Define true generating function on the grid.
ys = tf.math.sin(1*np.pi*xs) + tf.math.sin(0.5*np.pi*xs)
# Get 20 random points. 
n_obs = 20
indices = np.random.uniform(low=0.0, high=xs.shape[0] - 1, size=n_obs )
indices = [int(x) for x in np.round(indices)]
x_obs = tf.gather(params=xs, indices=indices)
y_obs = tf.gather(params=ys, indices=indices)
# Add Gaussian noise. 
y_obs = y_obs[..., 0] + tf.random.normal(mean=0.0, stddev=0.05, shape=(n_obs, ))

# Plot.
fig, ax = plt.subplots()
sns.lineplot(x=xs[..., 0], y=ys[..., 0], color=sns_c[0], alpha=0.3, ax=ax)
sns.scatterplot(x=x_obs[..., 0], y=y_obs, color=sns_c[0], s=20, edgecolor=sns_c[0], ax=ax)
ax.set(title='Sample Data');

Assuming we know the data is periodic (up to Gaussian noise), we can use an ExpSinSquared kernel to model the data.

kernel = tfp.math.psd_kernels.ExpSinSquared(amplitude=0.5, length_scale=1.0, period=4)

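gprm = tfd.GaussianProcessRegressionModel(
    kernel=kernel,
    index_points=xs,
    observation_index_points=x_obs,
    observations=y_obs,
    # Assumed value: the variance of the noise added above (stddev=0.05).
    observation_noise_variance=0.05 ** 2,
)
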
fig, ax = plt.subplots()
# Sample many times to get fits. 
for _ in range(25):
    sns.lineplot(x=xs[..., 0], y=gprm.sample(), c=sns_c[3], alpha=0.1, ax=ax)
# Compatibility interval.
upper, lower = gprm.mean() + [2 * gprm.stddev(), -2 * gprm.stddev()]

sns.lineplot(x=xs[..., 0], y=ys[..., 0], color=sns_c[0], alpha=0.3, ax=ax)
sns.scatterplot(x=x_obs[..., 0], y=y_obs, color=sns_c[0], s=40, edgecolor=sns_c[0], ax=ax)

ax.fill_between(xs[..., 0], upper, lower, color=sns_c[8], alpha=0.3, label=r'gp_mean $\pm$ 2std')
ax.set(title='Gaussian Process Regression Fit - ExpSinSquared Kernel');
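
To show what the regression model computes, here is a from-scratch sketch of the textbook posterior mean of a zero-mean Gaussian process, k(x, X) (K(X, X) + noise I)^-1 y. It is plain Python with a small hand-written linear solver, it uses five noise-free toy observations that are hypothetical, and it is not a description of how TFP implements the model.

```python
import math


def exp_sin_squared(x, y, amplitude=0.5, length_scale=1.0, period=4.0):
    """Periodic kernel with the same parameter values as the kernel above."""
    sine = math.sin(math.pi * abs(x - y) / period)
    return amplitude ** 2 * math.exp(-2.0 * sine ** 2 / length_scale ** 2)


def solve(matrix, rhs):
    """Solve matrix @ x = rhs by Gaussian elimination with partial pivoting."""
    n = len(rhs)
    rows = [list(row) + [value] for row, value in zip(matrix, rhs)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(rows[r][col]))
        rows[col], rows[pivot] = rows[pivot], rows[col]
        for r in range(col + 1, n):
            factor = rows[r][col] / rows[col][col]
            for c in range(col, n + 1):
                rows[r][c] -= factor * rows[col][c]
    x = [0.0] * n
    for r in reversed(range(n)):
        tail = sum(rows[r][c] * x[c] for c in range(r + 1, n))
        x[r] = (rows[r][n] - tail) / rows[r][r]
    return x


def posterior_mean(x_new, x_obs, y_obs, noise_variance):
    """Zero-mean GP posterior mean: k(x_new, X) @ (K(X, X) + noise * I)^-1 @ y."""
    n = len(x_obs)
    gram = [
        [exp_sin_squared(x_obs[i], x_obs[j]) + (noise_variance if i == j else 0.0)
         for j in range(n)]
        for i in range(n)
    ]
    weights = solve(gram, y_obs)
    return sum(exp_sin_squared(x_new, x_i) * w for x_i, w in zip(x_obs, weights))


def true_function(x):
    return math.sin(math.pi * x) + math.sin(0.5 * math.pi * x)


# Five noise-free observations inside one period (hypothetical toy data).
x_obs = [0.3, 1.1, 1.9, 2.7, 3.5]
y_obs = [true_function(x) for x in x_obs]

at_observation = posterior_mean(1.1, x_obs, y_obs, noise_variance=1e-6)
one_period_later = posterior_mean(1.1 + 4.0, x_obs, y_obs, noise_variance=1e-6)
print(y_obs[1], at_observation, one_period_later)
```

With almost no observation noise the posterior mean passes through the data, and because the kernel is periodic the prediction repeats one period later even though no data was observed there.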

Hidden Markov Model

The final type of distribution we want to touch on is the Hidden Markov Model. There is a very illustrative example in the distribution documentation, where a simple Hidden Markov Model is used to predict temperature as a function of the hidden state (whether it is a cold or a hot day). Here we discuss a different (but rather well-known) example studied in this presentation by Ben Langmead.

Example: Dealer repeatedly flips a coin. Sometimes the coin is fair, with P(heads) = 0.5, sometimes it’s loaded, with P(heads) = 0.8. Between each flip, dealer switches coins (invisibly) with prob. 0.4.

# We assume there is a 50/50 probability that the dealer begins with the fair coin. 
initial_distribution = tfd.Categorical(probs=[0.5, 0.5])
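# Transition state matrix.
transition_distribution = tfd.Categorical(
    probs=[[0.6, 0.4],
           [0.4, 0.6]]
)
# Emission probabilities of the observed states (0 = heads, 1 = tails).
observation_distribution = tfd.Categorical(
    probs=[[0.5, 0.5],
           [0.8, 0.2]]
)
# Define Hidden Markov Model.
hmm_model = tfd.HiddenMarkovModel(
    initial_distribution=initial_distribution,
    transition_distribution=transition_distribution,
    observation_distribution=observation_distribution,
    num_steps=100
)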

Let us sample from this model (sequences of length 100).

hmm_samples = hmm_model.sample(10000)

Let us compute the frequency of tails for each sequence.

hmm_samples_mean = tf.math.reduce_mean(input_tensor=tf.cast(x=hmm_samples, dtype='float32'), axis=1)
# Let us plot the distribution.
hmm_samples_mean_mean = tf.math.reduce_mean(input_tensor=hmm_samples_mean, axis=0).numpy()
hmm_samples_mean_std = tf.math.reduce_std(input_tensor=hmm_samples_mean, axis=0).numpy()
hmm_samples_plus = hmm_samples_mean_mean + 2*hmm_samples_mean_std
hmm_samples_minus = hmm_samples_mean_mean - 2*hmm_samples_mean_std

fig, ax = plt.subplots()
sns.distplot(a=hmm_samples_mean, rug=True, ax=ax)
ax.axvline(x=hmm_samples_mean_mean, color=sns_c[2], linestyle='--', label=rf'$\mu$ = {hmm_samples_mean_mean: 0.3f}')
ax.axvline(x=hmm_samples_plus, color=sns_c[1], linestyle='--', label=rf'$\mu + 2\sigma$ = {hmm_samples_plus: 0.3f}')
ax.axvline(x=hmm_samples_minus, color=sns_c[1], linestyle='--', label=rf'$\mu - 2\sigma$ = {hmm_samples_minus: 0.3f}')
ax.legend()
ax.set(title='Hidden Markov Model Samples Mean');

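Observe that the mean sample frequency of tails is ~ 0.35 < 0.5. This agrees with the model: the dealer holds each coin half of the time on average, so P(tails) = 0.5 × 0.5 + 0.5 × 0.2 = 0.35.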
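
The same figure can be recovered without TFP by simulating the dealer directly. The sketch below is plain Python and assumes the encoding used in this example (hidden state 0 = fair coin, 1 = loaded coin; observation 0 = heads, 1 = tails). The function name, the number of sequences, and the seed are arbitrary choices.

```python
import random

# Encoding assumed throughout: hidden state 0 = fair coin, 1 = loaded coin;
# observation 0 = heads, 1 = tails.
P_SWITCH = 0.4
P_TAILS = [0.5, 0.2]


def simulate_flips(n_steps, rng):
    """Simulate one sequence of coin flips from the dealer model."""
    state = 0 if rng.random() < 0.5 else 1  # 50/50 initial coin
    flips = []
    for _ in range(n_steps):
        flips.append(1 if rng.random() < P_TAILS[state] else 0)
        if rng.random() < P_SWITCH:  # the dealer swaps coins between flips
            state = 1 - state
    return flips


rng = random.Random(0)  # arbitrary fixed seed for reproducibility
sequences = [simulate_flips(n_steps=100, rng=rng) for _ in range(2_000)]
tails_frequency = sum(sum(seq) for seq in sequences) / (2_000 * 100)

# Expected value when each coin is in use half of the time.
expected = 0.5 * P_TAILS[0] + 0.5 * P_TAILS[1]
print(tails_frequency, expected)
```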

We could also ask the question: What is the probability of getting 10 heads in a row?

# We can simply run the `prob` method.
hmm_model.prob(tf.zeros(10, dtype='int32'))
<tf.Tensor: shape=(), dtype=float32, numpy=0.015090796>

The answer is around 1.5%. We could ask ourselves the same question for a fair coin:

tfd.Binomial(total_count=10, probs=0.5).prob(value=0)
<tf.Tensor: shape=(), dtype=float32, numpy=0.0009765625>

This value is much smaller, as expected.
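
To see where the Hidden Markov Model probability comes from, here is a sketch of the forward algorithm in plain Python, using the same initial, transition, and emission probabilities as the model above. It computes the probability that the first ten flips are all heads; the function name is hypothetical, and the sketch is not TFP's implementation.

```python
def sequence_probability(observations, initial, transition, emission):
    """Forward algorithm: probability of an observation sequence under an HMM."""
    n_states = len(initial)
    # alpha[i] = P(observations so far, current hidden state = i).
    alpha = [initial[i] * emission[i][observations[0]] for i in range(n_states)]
    for obs in observations[1:]:
        alpha = [
            sum(alpha[i] * transition[i][j] for i in range(n_states)) * emission[j][obs]
            for j in range(n_states)
        ]
    return sum(alpha)


initial = [0.5, 0.5]
transition = [[0.6, 0.4], [0.4, 0.6]]
emission = [[0.5, 0.5], [0.8, 0.2]]  # rows: fair, loaded; columns: heads (0), tails (1)

ten_heads = [0] * 10
p_hmm = sequence_probability(ten_heads, initial, transition, emission)
p_fair = 0.5 ** 10
print(p_hmm, p_fair)
```

The loaded coin makes long runs of heads roughly fifteen times more likely than they would be with a fair coin alone.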

This was just a glimpse of the distributions module of TensorFlow Probability. We have not done anything fancy yet, but this is the foundation for the inference and prediction problems to come.