Add Pyro example (#109)
* Add pyro example

* Add readme

* Fix VI with tree_map
SamDuffield authored Aug 12, 2024
1 parent 55dd3e9 commit 1a7eeba
Showing 12 changed files with 806 additions and 2 deletions.
3 changes: 3 additions & 0 deletions examples/README.md
@@ -9,6 +9,9 @@ on a series of books from the [pg19](https://huggingface.co/datasets/pg19) dataset
- [`imdb`](imdb/): Investigates [cold posterior effect](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf)
for a range of approximate Bayesian methods on [IMDB](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data)
data.
- [`pyro_pima_indians`](pyro_pima_indians/): Uses `pyro` to define a Bayesian logistic
regression model for the [Pima Indians Diabetes Database](https://www.kaggle.com/uciml/pima-indians-diabetes-database).
Compares `posteriors` methods against `pyro` and `blackjax`.
- [`yelp`](yelp/): Compares a host of `posteriors` methods (highlighting the easy
exchangeability) on a sentiment analysis task adapted from the [Hugging Face tutorial](https://huggingface.co/docs/transformers/training#train-in-native-pytorch).
- [`continual_regression`](continual_regression.ipynb): [Variational continual learning](https://arxiv.org/abs/1710.10628)
49 changes: 49 additions & 0 deletions examples/pyro_pima_indians/README.md
@@ -0,0 +1,49 @@
# Using `posteriors` with `pyro`

In this example, we show how to use `posteriors` with [`pyro`](https://pyro.ai/) to define a Bayesian logistic regression model for the [Pima Indians Diabetes Database](https://www.kaggle.com/uciml/pima-indians-diabetes-database). The model can then be used to automatically generate
a `log_posterior` function that can be directly passed to `posteriors`.

This specific model is small, with an 8-dimensional parameter vector and 768 data points.
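
For a flavour of the workflow, here is a minimal sketch of plugging the generated
`log_posterior` into a `posteriors` algorithm. It assumes `posteriors`' usual
`build`/`init`/`update` pattern; the learning rate and iteration count are
illustrative only (see the run files for the configurations actually used).

```python
import torch
import posteriors

from examples.pyro_pima_indians.model import load_data, load_model

X_all, y_all = load_data()
model, log_posterior = load_model(num_data=X_all.shape[0])

# Build an SGHMC transform around the pyro-generated log posterior
transform = posteriors.sgmcmc.sghmc.build(log_posterior, lr=1e-1)

state = transform.init(torch.zeros(X_all.shape[1]))

samples = []
for _ in range(1000):
    state = transform.update(state, (X_all, y_all))  # full batch
    samples.append(state.params.clone())
```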


## Results


<p align="center">
<img src="https://storage.googleapis.com/posteriors/pima_indians_full_batch_marginals.png" width="60%">
<img src="https://storage.googleapis.com/posteriors/pima_indians_mini_batch_marginals.png" width="60%">
<br>
<em>Marginal posterior densities for a variety of methods and package
implementations. Top: Full batch, bottom: mini-batched.</em>
</p>


In the above figure, we show the marginal posterior densities for a variety of methods and package implementations. We observe broad agreement between the approximations, indicating that all methods have converged, taking Pyro's [NUTS](http://www.stat.columbia.edu/~gelman/research/published/nuts.pdf) implementation as the gold standard.


<p align="center">
<img src="https://storage.googleapis.com/posteriors/pima_indians_metrics.png" width="45%">
<br>
<em>Kernelized Stein discrepancy (KSD) measures the distance between the samples provided by the algorithm and the true posterior via a kernel function (in this case a standard Gaussian). All results are averaged over 10 random seeds with one standard deviation displayed. †The displayed parallel SGHMC time represents the time for a single chain that could be obtained with sufficient parallel resources.</em>
</p>

In the table above we compare kernelized Stein discrepancies between methods as a quantitative
measure of the distance between the collected samples and the true posterior, confirming the
qualitative observation from the marginal posterior densities that `posteriors` methods are
competitive and suitably converged. For this small-scale example the Python overheads are
significant, as demonstrated by the only minor speedup gained from mini-batching (the
overheads are smaller for JAX in this setting, although in our experience this advantage
rapidly deteriorates for larger models).
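
For reference, the squared KSD computed in [`ksd.py`](ksd.py) is the V-statistic of the
Stein kernel $k_0$ induced by the standard Gaussian kernel
$k(x, y) = \exp(-\tfrac{1}{2}\lVert x - y \rVert^2)$:

$$
\mathrm{KSD}^2 = \frac{1}{n^2} \sum_{i,j=1}^{n} k_0(x_i, x_j),
$$

$$
k_0(x, y) = s(x)^\top s(y) \, k(x, y)
+ s(x)^\top \nabla_y k(x, y)
+ s(y)^\top \nabla_x k(x, y)
+ \sum_d \frac{\partial^2 k(x, y)}{\partial x_d \, \partial y_d},
$$

where $s(x) = \nabla_x \log p(x \mid \text{data})$ is the posterior score. The KSD
therefore only requires gradients of the unnormalized log posterior, which is what makes
it convenient for comparing sampling algorithms.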

## Code Structure

- [`model.py`](model.py): Defines the Bayesian logistic regression model using `pyro`
and loads the required functions for inference in `torch` or `jax`.
- The run files are end-to-end implementations (aside from the model loading) for the
labeled packages and methods.
- [`plot_marginals.py`](plot_marginals.py): Plots the marginal posterior densities in
the figure above.
- [`calculate_metrics.py`](calculate_metrics.py): Calculates the metrics in the table
above using the functions in [`ksd.py`](ksd.py) to calculate the kernelized Stein
discrepancies.
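
As a quick check of the model utilities, the `torch` and `jax` log posteriors from
[`model.py`](model.py) can be evaluated side by side. A minimal sketch; on the full
batch the two normalized log posteriors should agree up to floating point error:

```python
import jax.numpy as jnp
import torch

from examples.pyro_pima_indians.model import load_data, load_jax_model, load_model

X_all, y_all = load_data()
num_data = X_all.shape[0]

_, torch_log_post = load_model(num_data)
jax_log_post = load_jax_model(num_data)

torch_val, _ = torch_log_post(torch.zeros(X_all.shape[1]), (X_all, y_all))
jax_val = jax_log_post(
    jnp.zeros(X_all.shape[1]), (jnp.asarray(X_all.numpy()), jnp.asarray(y_all.numpy()))
)
```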

75 changes: 75 additions & 0 deletions examples/pyro_pima_indians/calculate_metrics.py
@@ -0,0 +1,75 @@
import pickle
import pprint
import torch
import numpy as np
from tqdm import tqdm

from examples.pyro_pima_indians.model import load_data, load_model
from examples.pyro_pima_indians.ksd import ksd

name_dict = {
"Pyro (NUTS)": "examples/pyro_pima_indians/results/pyro.pkl",
"BlackJAX (SGHMC, Full Batch)": "examples/pyro_pima_indians/results/blackjax_sghmc.pkl",
"posteriors (SGHMC, Full Batch)": "examples/pyro_pima_indians/results/posteriors_sghmc_None.pkl",
"posteriors (SGHMC, Batch Size=32)": "examples/pyro_pima_indians/results/posteriors_sghmc_32.pkl",
"posteriors (VI, Full Batch)": "examples/pyro_pima_indians/results/posteriors_vi_None.pkl",
"posteriors (VI, Batch Size=32)": "examples/pyro_pima_indians/results/posteriors_vi_32.pkl",
"posteriors (Parallel SGHMC, Batch Size=32)": "examples/pyro_pima_indians/results/posteriors_sghmc_parallel_32.pkl",
}


sample_dict = {}
time_dict = {}
# Load the saved samples and run times for each method
for key, path in name_dict.items():
    with open(path, "rb") as f:
        save_dict = pickle.load(f)
    sample_dict[key] = save_dict["samples"]
    time_dict[key] = save_dict["times"]


time_summs_dict = {
key: (time_dict[key].mean(), time_dict[key].std()) for key in time_dict
}

# Print the mean and standard deviation of the time taken for each method
pprint.pprint(time_summs_dict)


# Kernelized Stein discrepancy
# The KSD measures the discrepancy between a collection of samples and the
# posterior distribution, using only gradients of the log posterior.

ksd_batchsize = 100
ksd_save_dir = "examples/pyro_pima_indians/results/ksd.pickle"

X_all, y_all = load_data()
dim = X_all.shape[1]
num_data = X_all.shape[0]

model, log_posterior_normalized = load_model(num_data)


def grad_log_posterior_normalized(params):
    # Score function: gradient of the normalized log posterior with respect
    # to the parameters, evaluated on the full dataset
    return torch.func.grad(lambda p: log_posterior_normalized(p, (X_all, y_all))[0])(
        params
    )


def ksd_via_grads(samples, batchsize=None):
gradients = torch.stack([grad_log_posterior_normalized(s) for s in samples])
return ksd(samples, gradients, batchsize=batchsize)


ksd_dict = {}
for key, samples in tqdm(sample_dict.items()):
ksd_dict[key] = np.array(
[ksd_via_grads(torch.tensor(s), batchsize=ksd_batchsize) for s in samples]
)

ksd_summs_dict = {key: (ksd_dict[key].mean(), ksd_dict[key].std()) for key in ksd_dict}


# Print the mean and standard deviation of the KSD for each method
pprint.pprint(ksd_summs_dict)

with open(ksd_save_dir, "wb") as f:
pickle.dump(ksd_summs_dict, f)
47 changes: 47 additions & 0 deletions examples/pyro_pima_indians/ksd.py
@@ -0,0 +1,47 @@
import torch


def gaussian_kernel(x, y):
    # Standard Gaussian kernel k(x, y) = exp(-0.5 * ||x - y||^2)
    return torch.exp(-0.5 * (x - y).pow(2).sum(-1))


def gaussian_kernel_dx(x, y):
    # Gradient with respect to the first argument:
    # grad_x k(x, y) = -(x - y) * k(x, y)
    return -(x - y) * gaussian_kernel(x, y)


def gaussian_kernel_dy(x, y):
    # Gradient with respect to the second argument:
    # grad_y k(x, y) = (x - y) * k(x, y) = -grad_x k(x, y)
    return -gaussian_kernel_dx(x, y)


def gaussian_kernel_diag_dxdy(x, y):
    # Diagonal of the mixed second derivative:
    # d^2 k / dx_d dy_d = (1 - (x_d - y_d)^2) * k(x, y)
    return (1 - (x - y).pow(2)) * gaussian_kernel(x, y)


def ksd(samples, gradients, batchsize=None):
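    # Estimate the kernelized Stein discrepancy between `samples` (n, d) and a
    # target distribution whose score at each sample is supplied in `gradients`
    # (n, d). If `batchsize` is given, the inner expectation over pairs is
    # subsampled, reducing the cost from O(n^2) to O(n * batchsize) kernel
    # evaluations at the price of extra variance.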
n = samples.shape[0]
if batchsize is None:

def get_batch_inds():
return torch.arange(n)
else:

def get_batch_inds():
return torch.randint(n, size=(batchsize,))

def k0(sampsi, sampsj, gradsi, gradsj):
return (
torch.sum(gaussian_kernel_diag_dxdy(sampsi, sampsj))
+ torch.dot(gaussian_kernel_dx(sampsi, sampsj), gradsj)
+ torch.dot(gaussian_kernel_dy(sampsi, sampsj), gradsi)
+ torch.dot(gradsi, gradsj) * gaussian_kernel(sampsi, sampsj)
)

def v_k_0(sampsi, gradsi):
batch_inds = get_batch_inds()
return torch.vmap(k0, in_dims=(None, 0, None, 0))(
sampsi, samples[batch_inds], gradsi, gradients[batch_inds]
).mean()

return torch.sqrt(
torch.vmap(v_k_0, randomness="different")(samples, gradients).mean()
)
82 changes: 82 additions & 0 deletions examples/pyro_pima_indians/model.py
@@ -0,0 +1,82 @@
import torch
import pandas as pd
import pyro
import pyro.distributions as dist
import jax


def load_data():
# Load the Pima Indians Diabetes dataset
data_url = "https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv"
column_names = [
"num_pregnant",
"glucose_concentration",
"blood_pressure",
"skin_thickness",
"serum_insulin",
"bmi",
"diabetes_pedigree",
"age",
"class",
]
data = pd.read_csv(data_url, header=None, names=column_names)

# Preprocess the data
X_all = data.drop(columns=["class"]).values
y_all = data["class"].values

# Normalize the data
X_mean = X_all.mean(axis=0)
X_std = X_all.std(axis=0)
X_all = (X_all - X_mean) / X_std

# Convert to torch tensors
X_all = torch.tensor(X_all, dtype=torch.float)
y_all = torch.tensor(y_all, dtype=torch.float)
return X_all, y_all


def load_model(num_data):
# Define the logistic regression model with pyro
def model(data):
X, y = data

batchsize = X.shape[0]

# Define the priors
w = pyro.sample(
"w",
dist.Normal(torch.zeros(X.shape[1]), scale=(num_data / batchsize) ** 0.5),
) # Scale to ensure the prior variance is 1 for all batch sizes

# Define the logistic regression model
logits = torch.matmul(X, w)
y_pred = torch.sigmoid(logits)

return pyro.sample("obs", dist.Bernoulli(y_pred), obs=y)

# Define the log posterior function using Pyro's tracing utilities
def log_posterior_normalized(params, batch):
X, y = batch
batchsize = X.shape[0]
conditioned_model = pyro.condition(model, data={"w": params})
model_trace = pyro.poutine.trace(conditioned_model).get_trace((X, y))
log_joint = model_trace.log_prob_sum()
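        # Normalize by batch size and return an empty aux tensor, matching the
        # (value, aux) output convention used by posteriors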
return log_joint / batchsize, torch.tensor([])

return model, log_posterior_normalized


def load_jax_model(num_data):
def jax_log_posterior_normalized(params, batch):
X, y = batch
batch_size = X.shape[0]
logits = jax.numpy.matmul(X, params)
y_pred = jax.nn.sigmoid(logits)
return (
jax.numpy.sum(jax.numpy.log(y_pred * y + (1 - y_pred) * (1 - y)), axis=0)
/ batch_size
+ jax.scipy.stats.norm.logpdf(params).sum() / num_data
)

return jax_log_posterior_normalized
102 changes: 102 additions & 0 deletions examples/pyro_pima_indians/plot_marginals.py
@@ -0,0 +1,102 @@
import pickle
import matplotlib.pyplot as plt
import seaborn as sns

plt.ion()

plt.rcParams["font.family"] = "Times New Roman"

plot_name = "full_batch"
name_dict = {
"Pyro (NUTS)": ["examples/pyro_pima_indians/results/pyro.pkl", "grey"],
"BlackJAX (SGHMC, Full Batch)": [
"examples/pyro_pima_indians/results/blackjax_sghmc.pkl",
"black",
],
"posteriors (SGHMC, Full Batch)": [
"examples/pyro_pima_indians/results/posteriors_sghmc_None.pkl",
"firebrick",
],
"posteriors (VI, Full Batch)": [
"examples/pyro_pima_indians/results/posteriors_vi_None.pkl",
"forestgreen",
],
}

# plot_name = "mini_batch"
# name_dict = {
# "Pyro (NUTS)": ["examples/pyro_pima_indians/results/pyro.pkl", "grey"],
# "posteriors (SGHMC, Batch Size=32)": [
# "examples/pyro_pima_indians/results/posteriors_sghmc_32.pkl",
# "firebrick",
# ],
# "posteriors (VI, Batch Size=32)": [
# "examples/pyro_pima_indians/results/posteriors_vi_32.pkl",
# "forestgreen",
# ],
# "posteriors (Parallel SGHMC, Batch Size=32)": [
# "examples/pyro_pima_indians/results/posteriors_sghmc_parallel_32.pkl",
# "tomato",
# ],
# }


column_names = [
"num_pregnant",
"glucose_concentration",
"blood_pressure",
"skin_thickness",
"serum_insulin",
"bmi",
"diabetes_pedigree",
"age",
"class",
]


sample_dict = {}
for key, (path, _) in name_dict.items():
    with open(path, "rb") as f:
        save_dict = pickle.load(f)
    sample_dict[key] = save_dict["samples"]


# Plot the marginals
samp_index = 1


def plot_kernel_density(samples, ax, label, color=None):
sns.kdeplot(samples, ax=ax, label=label, color=color, bw_adjust=1.5)


fig, axes = plt.subplots(2, 4, figsize=(10, 5))


for dim_ind, ax in enumerate(axes.ravel()):
for k in name_dict.keys():
        samps = sample_dict[k][samp_index][:, dim_ind]
        plot_kernel_density(
            samps,
            ax,
            label=k if dim_ind == 0 else None,
            color=name_dict[k][1],
        )
ax.set_xlabel(column_names[dim_ind])

# Remove y-axis label and ticks
ax.set_ylabel("")
ax.set_yticks([])

# Remove frames
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.spines["left"].set_visible(False)

fig.legend(framealpha=0.3, frameon=True)


fig.tight_layout()
fig.savefig(f"examples/pyro_pima_indians/results/{plot_name}_marginals.png", dpi=400)