# Commit `1a7eeba`

* Add pyro example
* Add readme
* Fix VI with tree_map
1 parent `55dd3e9`. Showing 12 changed files with 806 additions and 2 deletions.
**`examples/pyro_pima_indians/README.md`** (new file, +49 lines)

# Using `posteriors` with `pyro`

In this example, we show how to use `posteriors` with [`pyro`](https://pyro.ai/) to define a Bayesian logistic regression model for the [Pima Indians Diabetes Database](https://www.kaggle.com/uciml/pima-indians-diabetes-database). The model can then be used to automatically generate a `log_posterior` function that can be passed directly to `posteriors`.
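As a minimal sketch of the tracing approach used in [`model.py`](model.py) (the `make_log_posterior` wrapper and the single `"w"` parameter site here are simplifications for illustration):

```python
import pyro


def make_log_posterior(model):
    # Condition the pyro model on a value for its "w" parameter site,
    # then trace the conditioned model to accumulate the joint log density
    def log_posterior(params, batch):
        conditioned = pyro.condition(model, data={"w": params})
        trace = pyro.poutine.trace(conditioned).get_trace(batch)
        return trace.log_prob_sum()

    return log_posterior
```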
This specific model is small, with 8 dimensions and 768 data points.

## Results

<p align="center">
  <img src="https://storage.googleapis.com/posteriors/pima_indians_full_batch_marginals.png" width="60%">
  <img src="https://storage.googleapis.com/posteriors/pima_indians_mini_batch_marginals.png" width="60%">
  <br>
  <em>Marginal posterior densities for a variety of methods and package
  implementations. Top: full batch; bottom: mini-batched.</em>
</p>

In the figure above, we show the marginal posterior densities for a variety of methods and package implementations. Taking Pyro's [NUTS](http://www.stat.columbia.edu/~gelman/research/published/nuts.pdf) implementation as a gold standard, we observe broad agreement between the approximations, indicating that all methods have converged.
<p align="center">
  <img src="https://storage.googleapis.com/posteriors/pima_indians_metrics.png" width="45%">
  <br>
  <em>Kernelized Stein discrepancy (KSD) measures the distance between the samples provided by the algorithm and the true posterior via a kernel function (in this case a standard Gaussian). All results are averaged over 10 random seeds, with one standard deviation displayed. †The displayed parallel SGHMC time represents the time for a single chain, which could be obtained with sufficient parallel resources.</em>
</p>
In the table above, we compare kernelized Stein discrepancies between methods as a quantitative measure of the distance between the collected samples and the true posterior, confirming the qualitative observation from the marginal posterior densities that `posteriors` methods are competitive and suitably converged. For this small-scale example the Python overheads are significant, as shown by the only minor speedup from mini-batching (the overheads are smaller for JAX in this setting, although in our experience this advantage rapidly deteriorates for larger models).
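For reference, [`ksd.py`](ksd.py) computes the V-statistic form of the KSD. Writing $s(x) = \nabla_x \log \pi(x)$ for the posterior score and $k(x, y) = \exp(-\tfrac{1}{2} \lVert x - y \rVert^2)$ for the Gaussian kernel, the estimate over samples $x_1, \dots, x_n$ is

```math
\mathrm{KSD} = \left( \frac{1}{n^2} \sum_{i=1}^{n} \sum_{j=1}^{n} k_0(x_i, x_j) \right)^{1/2},
\quad
k_0(x, y) = \nabla_x \cdot \nabla_y k(x, y) + s(y)^\top \nabla_x k(x, y) + s(x)^\top \nabla_y k(x, y) + s(x)^\top s(y) \, k(x, y),
```

with the scores $s(x_i)$ obtained by differentiating the `log_posterior` function. When a batch size is supplied, the inner sum is replaced by an average over a random mini-batch of the samples.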
## Code Structure

- [`model.py`](model.py): Defines the Bayesian logistic regression model using `pyro`
  and loads the required functions for inference in `torch` or `jax`.
- The run files are end-to-end implementations (aside from the model loading) for the
  labeled packages and methods.
- [`plot_marginals.py`](plot_marginals.py): Plots the marginal posterior densities in
  the figure above.
- [`calculate_metrics.py`](calculate_metrics.py): Calculates the metrics in the table
  above using the functions in [`ksd.py`](ksd.py) to calculate the kernelized Stein
  discrepancies.
**`examples/pyro_pima_indians/calculate_metrics.py`** (new file, +75 lines)
```python
import pickle
import pprint
import torch
import numpy as np
from tqdm import tqdm

from examples.pyro_pima_indians.model import load_data, load_model
from examples.pyro_pima_indians.ksd import ksd

name_dict = {
    "Pyro (NUTS)": "examples/pyro_pima_indians/results/pyro.pkl",
    "BlackJAX (SGHMC, Full Batch)": "examples/pyro_pima_indians/results/blackjax_sghmc.pkl",
    "posteriors (SGHMC, Full Batch)": "examples/pyro_pima_indians/results/posteriors_sghmc_None.pkl",
    "posteriors (SGHMC, Batch Size=32)": "examples/pyro_pima_indians/results/posteriors_sghmc_32.pkl",
    "posteriors (VI, Full Batch)": "examples/pyro_pima_indians/results/posteriors_vi_None.pkl",
    "posteriors (VI, Batch Size=32)": "examples/pyro_pima_indians/results/posteriors_vi_32.pkl",
    "posteriors (Parallel SGHMC, Batch Size=32)": "examples/pyro_pima_indians/results/posteriors_sghmc_parallel_32.pkl",
}

# Load the saved samples and run times for each method
sample_dict = {}
time_dict = {}
for key, path in name_dict.items():
    with open(path, "rb") as f:
        save_dict = pickle.load(f)
    sample_dict[key] = save_dict["samples"]
    time_dict[key] = save_dict["times"]

time_summs_dict = {
    key: (time_dict[key].mean(), time_dict[key].std()) for key in time_dict
}

# Print the mean and standard deviation of the time taken for each method
pprint.pprint(time_summs_dict)

# Kernelized Stein discrepancy (KSD)
# The KSD is a measure of the difference between a collection of samples
# and a log posterior function, accessed through its gradients (scores).

ksd_batchsize = 100
ksd_save_dir = "examples/pyro_pima_indians/results/ksd.pickle"

X_all, y_all = load_data()
dim = X_all.shape[1]
num_data = X_all.shape[0]

model, log_posterior_normalized = load_model(num_data)


def grad_log_posterior_normalized(params):
    return torch.func.grad(lambda p: log_posterior_normalized(p, (X_all, y_all))[0])(
        params
    )


def ksd_via_grads(samples, batchsize=None):
    gradients = torch.stack([grad_log_posterior_normalized(s) for s in samples])
    return ksd(samples, gradients, batchsize=batchsize)


# Compute the KSD for each method, over all random seeds
ksd_dict = {}
for key, samples in tqdm(sample_dict.items()):
    ksd_dict[key] = np.array(
        [ksd_via_grads(torch.tensor(s), batchsize=ksd_batchsize) for s in samples]
    )

ksd_summs_dict = {key: (ksd_dict[key].mean(), ksd_dict[key].std()) for key in ksd_dict}

# Print the mean and standard deviation of the KSD for each method
pprint.pprint(ksd_summs_dict)

with open(ksd_save_dir, "wb") as f:
    pickle.dump(ksd_summs_dict, f)
```
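The script above assumes each results file is a pickled dict with `"samples"` (a sequence over random seeds of `(num_samples, dim)` sample arrays) and `"times"` (an array of per-seed run times), as produced by the run files. A hypothetical sketch of a compatible file, useful for exercising the script without running the samplers:

```python
import pickle
import numpy as np

# Hypothetical stand-in for a results file: "samples" holds one
# (num_samples, dim) array per random seed, "times" one run time per seed
dummy = {
    "samples": [np.random.randn(1000, 8) for _ in range(10)],
    "times": np.random.rand(10),
}
with open("examples/pyro_pima_indians/results/dummy.pkl", "wb") as f:
    pickle.dump(dummy, f)
```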
**`examples/pyro_pima_indians/ksd.py`** (new file, +47 lines)
```python
import torch


def gaussian_kernel(x, y):
    # k(x, y) = exp(-0.5 * ||x - y||^2)
    return torch.exp(-0.5 * (x - y).pow(2).sum(-1))


def gaussian_kernel_dx(x, y):
    # Gradient w.r.t. the first argument: grad_x k(x, y) = -(x - y) k(x, y)
    return -(x - y) * gaussian_kernel(x, y)


def gaussian_kernel_dy(x, y):
    # Gradient w.r.t. the second argument: grad_y k(x, y) = (x - y) k(x, y)
    return -gaussian_kernel_dx(x, y)


def gaussian_kernel_diag_dxdy(x, y):
    # Elementwise cross derivatives: d2k / dx_d dy_d = (1 - (x_d - y_d)^2) k(x, y)
    return (1 - (x - y).pow(2)) * gaussian_kernel(x, y)


def ksd(samples, gradients, batchsize=None):
    n = samples.shape[0]
    if batchsize is None:

        def get_batch_inds():
            return torch.arange(n)
    else:

        def get_batch_inds():
            return torch.randint(n, size=(batchsize,))

    def k0(sampsi, sampsj, gradsi, gradsj):
        # Stein kernel:
        # k0(x, y) = sum_d d2k/dx_d dy_d + s(y)^T grad_x k + s(x)^T grad_y k
        #            + s(x)^T s(y) k
        return (
            torch.sum(gaussian_kernel_diag_dxdy(sampsi, sampsj))
            + torch.dot(gaussian_kernel_dx(sampsi, sampsj), gradsj)
            + torch.dot(gaussian_kernel_dy(sampsi, sampsj), gradsi)
            + torch.dot(gradsi, gradsj) * gaussian_kernel(sampsi, sampsj)
        )

    def v_k_0(sampsi, gradsi):
        # Average the Stein kernel against all (or a mini-batch of) samples
        batch_inds = get_batch_inds()
        return torch.vmap(k0, in_dims=(None, 0, None, 0))(
            sampsi, samples[batch_inds], gradsi, gradients[batch_inds]
        ).mean()

    return torch.sqrt(
        torch.vmap(v_k_0, randomness="different")(samples, gradients).mean()
    )
```
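As a quick sanity check of `ksd` (not part of the commit): for samples drawn from a standard Gaussian target the score is simply s(x) = -x, and the discrepancy should be small, shrinking as the number of samples grows.

```python
import torch

# Standard Gaussian target: log density -0.5 * ||x||^2 + const, so s(x) = -x
samples = torch.randn(500, 8)
gradients = -samples

print(ksd(samples, gradients))                 # exact V-statistic
print(ksd(samples, gradients, batchsize=100))  # mini-batched estimate
```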
**`examples/pyro_pima_indians/model.py`** (new file, +82 lines)
```python
import torch
import pandas as pd
import pyro
import pyro.distributions as dist
import jax


def load_data():
    # Load the Pima Indians Diabetes dataset
    data_url = "https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv"
    column_names = [
        "num_pregnant",
        "glucose_concentration",
        "blood_pressure",
        "skin_thickness",
        "serum_insulin",
        "bmi",
        "diabetes_pedigree",
        "age",
        "class",
    ]
    data = pd.read_csv(data_url, header=None, names=column_names)

    # Split into features and labels
    X_all = data.drop(columns=["class"]).values
    y_all = data["class"].values

    # Normalize the features
    X_mean = X_all.mean(axis=0)
    X_std = X_all.std(axis=0)
    X_all = (X_all - X_mean) / X_std

    # Convert to torch tensors
    X_all = torch.tensor(X_all, dtype=torch.float)
    y_all = torch.tensor(y_all, dtype=torch.float)
    return X_all, y_all


def load_model(num_data):
    # Define the logistic regression model with pyro
    def model(data):
        X, y = data

        batchsize = X.shape[0]

        # Define the prior, scaled so that the effective prior variance is 1
        # for all batch sizes once the log joint is normalized below
        w = pyro.sample(
            "w",
            dist.Normal(torch.zeros(X.shape[1]), scale=(num_data / batchsize) ** 0.5),
        )

        # Logistic regression likelihood
        logits = torch.matmul(X, w)
        y_pred = torch.sigmoid(logits)

        return pyro.sample("obs", dist.Bernoulli(y_pred), obs=y)

    # Define the log posterior function using pyro's tracing utilities
    def log_posterior_normalized(params, batch):
        X, y = batch
        batchsize = X.shape[0]
        conditioned_model = pyro.condition(model, data={"w": params})
        model_trace = pyro.poutine.trace(conditioned_model).get_trace((X, y))
        log_joint = model_trace.log_prob_sum()
        return log_joint / batchsize, torch.tensor([])

    return model, log_posterior_normalized


def load_jax_model(num_data):
    # Equivalent hand-written log posterior in JAX: Bernoulli likelihood
    # plus a standard normal prior, normalized per data point
    def jax_log_posterior_normalized(params, batch):
        X, y = batch
        batch_size = X.shape[0]
        logits = jax.numpy.matmul(X, params)
        y_pred = jax.nn.sigmoid(logits)
        return (
            jax.numpy.sum(jax.numpy.log(y_pred * y + (1 - y_pred) * (1 - y)), axis=0)
            / batch_size
            + jax.scipy.stats.norm.logpdf(params).sum() / num_data
        )

    return jax_log_posterior_normalized
```
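A short sketch of how these pieces fit together (the shapes follow from the dataset: 768 rows, 8 features):

```python
import torch

X_all, y_all = load_data()  # X_all: (768, 8), y_all: (768,)
model, log_posterior = load_model(num_data=X_all.shape[0])

# Evaluate the normalized log posterior (and empty auxiliary info)
# at a random parameter vector
w = torch.randn(X_all.shape[1])
log_post, aux = log_posterior(w, (X_all, y_all))
print(log_post)
```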
**`examples/pyro_pima_indians/plot_marginals.py`** (new file, +102 lines)
```python
import pickle
import matplotlib.pyplot as plt
import seaborn as sns

plt.ion()

plt.rcParams["font.family"] = "Times New Roman"

# Select which set of results to plot (full batch or mini-batch)
plot_name = "full_batch"
name_dict = {
    "Pyro (NUTS)": ["examples/pyro_pima_indians/results/pyro.pkl", "grey"],
    "BlackJAX (SGHMC, Full Batch)": [
        "examples/pyro_pima_indians/results/blackjax_sghmc.pkl",
        "black",
    ],
    "posteriors (SGHMC, Full Batch)": [
        "examples/pyro_pima_indians/results/posteriors_sghmc_None.pkl",
        "firebrick",
    ],
    "posteriors (VI, Full Batch)": [
        "examples/pyro_pima_indians/results/posteriors_vi_None.pkl",
        "forestgreen",
    ],
}

# plot_name = "mini_batch"
# name_dict = {
#     "Pyro (NUTS)": ["examples/pyro_pima_indians/results/pyro.pkl", "grey"],
#     "posteriors (SGHMC, Batch Size=32)": [
#         "examples/pyro_pima_indians/results/posteriors_sghmc_32.pkl",
#         "firebrick",
#     ],
#     "posteriors (VI, Batch Size=32)": [
#         "examples/pyro_pima_indians/results/posteriors_vi_32.pkl",
#         "forestgreen",
#     ],
#     "posteriors (Parallel SGHMC, Batch Size=32)": [
#         "examples/pyro_pima_indians/results/posteriors_sghmc_parallel_32.pkl",
#         "tomato",
#     ],
# }

column_names = [
    "num_pregnant",
    "glucose_concentration",
    "blood_pressure",
    "skin_thickness",
    "serum_insulin",
    "bmi",
    "diabetes_pedigree",
    "age",
    "class",
]

# Load the saved samples for each method
sample_dict = {}
for key, (path, _) in name_dict.items():
    with open(path, "rb") as f:
        save_dict = pickle.load(f)
    sample_dict[key] = save_dict["samples"]

# Plot the marginals for a single random seed
samp_index = 1


def plot_kernel_density(samples, ax, label, color=None):
    sns.kdeplot(samples, ax=ax, label=label, color=color, bw_adjust=1.5)


fig, axes = plt.subplots(2, 4, figsize=(10, 5))

for dim_ind, ax in enumerate(axes.ravel()):
    for k in name_dict.keys():
        samps = sample_dict[k][samp_index][:, dim_ind]
        plot_kernel_density(
            samps,
            ax,
            label=k if dim_ind == 0 else None,
            color=name_dict[k][1],
        )
    ax.set_xlabel(column_names[dim_ind])

    # Remove y-axis label and ticks
    ax.set_ylabel("")
    ax.set_yticks([])

    # Remove frames
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_visible(False)

fig.legend(framealpha=0.3, frameon=True)

fig.tight_layout()
fig.savefig(f"examples/pyro_pima_indians/results/{plot_name}_marginals.png", dpi=400)
```