
EKF (diagonal) #9

Merged
merged 11 commits on Feb 1, 2024
337 changes: 337 additions & 0 deletions examples/yelp/yelp_subspace_ekf_diag_hessian.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/yelp/yelp_subspace_vi_diag.ipynb
@@ -242,7 +242,7 @@
}
],
"source": [
"# Visualize the standard deviations of the Laplace approximation\n",
"# Visualize the standard deviations of the final Normal distribution\n",
"sd_diag = torch.cat([v.exp().detach().cpu().flatten() for v in vi_state.log_sd_diag.values()]).numpy()\n",
"\n",
"plt.hist(sd_diag, bins=100, density=True);"
Empty file added tests/ekf/__init__.py
Empty file.
71 changes: 71 additions & 0 deletions tests/ekf/test_diag_fisher.py
@@ -0,0 +1,71 @@
from functools import partial
from typing import Any
import torch
from optree import tree_map

from uqlib import ekf
from uqlib.utils import diag_normal_log_prob


def batch_normal_log_prob(
p: dict, batch: Any, mean: dict, sd_diag: dict
) -> torch.Tensor:
return diag_normal_log_prob(p, mean, sd_diag)


def test_ekf_diag():
torch.manual_seed(42)
target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean)

batch_normal_log_prob_spec = partial(
batch_normal_log_prob, mean=target_mean, sd_diag=target_sds
)

init_mean = tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)

batch = torch.arange(3).reshape(-1, 1)

n_steps = 1000
transform = ekf.diag_fisher.build(batch_normal_log_prob_spec, lr=1e-3)

state = transform.init(init_mean)

log_liks = []

for _ in range(n_steps):
state = transform.update(state, batch)
log_liks.append(state.log_likelihood)

for key in state.mean:
assert torch.allclose(state.mean[key], target_mean[key], atol=1e-1)

# Test inplace
state_ip = transform.init(init_mean)
state_ip2 = transform.update(
state_ip,
batch,
inplace=True,
)

for key in state_ip2.mean:
assert torch.allclose(state_ip2.mean[key], state_ip.mean[key], atol=1e-8)
assert torch.allclose(state_ip2.sd_diag[key], state_ip.sd_diag[key], atol=1e-8)

# Test not inplace
state_ip_false = transform.init(
tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)
)
state_ip_false2 = transform.update(
state_ip_false,
batch,
inplace=False,
)

for key in state_ip.mean:
assert not torch.allclose(
state_ip_false2.mean[key], state_ip_false.mean[key], atol=1e-8
)
assert not torch.allclose(
state_ip_false2.sd_diag[key], state_ip_false.sd_diag[key], atol=1e-8
)
71 changes: 71 additions & 0 deletions tests/ekf/test_diag_hessian.py
@@ -0,0 +1,71 @@
from functools import partial
from typing import Any
import torch
from optree import tree_map

from uqlib import ekf
from uqlib.utils import diag_normal_log_prob


def batch_normal_log_prob(
p: dict, batch: Any, mean: dict, sd_diag: dict
) -> torch.Tensor:
return diag_normal_log_prob(p, mean, sd_diag)


def test_ekf_diag():
torch.manual_seed(42)
target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean)

batch_normal_log_prob_spec = partial(
batch_normal_log_prob, mean=target_mean, sd_diag=target_sds
)

init_mean = tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)

batch = torch.arange(3).reshape(-1, 1)

n_steps = 1000
transform = ekf.diag_hessian.build(batch_normal_log_prob_spec, lr=1e-3)

state = transform.init(init_mean)

log_liks = []

for _ in range(n_steps):
state = transform.update(state, batch)
log_liks.append(state.log_likelihood)

for key in state.mean:
assert torch.allclose(state.mean[key], target_mean[key], atol=1e-1)

# Test inplace
state_ip = transform.init(init_mean)
state_ip2 = transform.update(
state_ip,
batch,
inplace=True,
)

for key in state_ip2.mean:
assert torch.allclose(state_ip2.mean[key], state_ip.mean[key], atol=1e-8)
assert torch.allclose(state_ip2.sd_diag[key], state_ip.sd_diag[key], atol=1e-8)

# Test not inplace
state_ip_false = transform.init(
tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)
)
state_ip_false2 = transform.update(
state_ip_false,
batch,
inplace=False,
)

for key in state_ip.mean:
assert not torch.allclose(
state_ip_false2.mean[key], state_ip_false.mean[key], atol=1e-8
)
assert not torch.allclose(
state_ip_false2.sd_diag[key], state_ip_false.sd_diag[key], atol=1e-8
)
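
Note: the two test files above exercise the same API and differ only in the transform under test, ekf.diag_fisher versus ekf.diag_hessian. In the diag_fisher variant (see uqlib/ekf/diag_fisher.py below) the diagonal curvature is approximated by averaging squared per-sample gradients, i.e. a diagonal empirical Fisher; the diag_hessian module is not shown in this diff and presumably estimates the diagonal Hessian directly. A minimal standalone sketch of the squared-gradient approximation, using an illustrative Gaussian log-likelihood:

import torch
from torch.func import vmap, jacrev

def single_log_lik(params: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Illustrative log N(x | params, I), up to an additive constant
    return -0.5 * ((x - params) ** 2).sum()

params = torch.zeros(3)
batch = torch.randn(5, 3)  # 5 samples, 3 features

# Vectorise over the batch dimension, then take per-sample gradients
batched_log_lik = vmap(single_log_lik, in_dims=(None, 0))
jac = jacrev(batched_log_lik)(params, batch)  # shape (5, 3): one gradient per sample

grad = jac.mean(0)                    # g(μ): average gradient
neg_diag_fisher = -(jac**2).mean(0)   # -F_d(μ): negative diagonal empirical Fisher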
Empty file added tests/laplace/__init__.py
Empty file.
5 changes: 4 additions & 1 deletion uqlib/__init__.py
@@ -1,7 +1,8 @@
from uqlib import ekf
from uqlib import laplace
from uqlib import vi
from uqlib import sgmcmc
from uqlib import types
from uqlib import vi

from uqlib.utils import model_to_function
from uqlib.utils import hvp
@@ -16,3 +17,5 @@
from uqlib.utils import insert_requires_grad_
from uqlib.utils import extract_requires_grad_and_func
from uqlib.utils import inplacify
from uqlib.utils import tree_map_inplacify_
from uqlib.utils import flexi_tree_map
2 changes: 2 additions & 0 deletions uqlib/ekf/__init__.py
@@ -0,0 +1,2 @@
from uqlib.ekf import diag_fisher
from uqlib.ekf import diag_hessian
170 changes: 170 additions & 0 deletions uqlib/ekf/diag_fisher.py
@@ -0,0 +1,170 @@
from typing import Callable, Any, NamedTuple
from functools import partial
import torch
from torch.func import vmap, jacrev
from optree import tree_map

from uqlib.types import TensorTree, Transform
from uqlib.utils import diag_normal_sample, flexi_tree_map


class EKFDiagState(NamedTuple):
"""State encoding a diagonal Normal distribution over parameters.

Args:
mean: Mean of the Normal distribution.
sd_diag: Square-root diagonal of the covariance matrix of the
Normal distribution.
log_likelihood: Log likelihood of the data given the parameters.
"""

mean: TensorTree
sd_diag: TensorTree
log_likelihood: float = 0


def init(
params: TensorTree,
init_sds: TensorTree | None = None,
) -> EKFDiagState:
"""Initialise diagonal Normal distribution over parameters.

Args:
params: Initial mean of the variational distribution.
init_sds: Initial square-root diagonal of the covariance matrix
of the variational distribution. Defaults to ones.

Returns:
Initial EKFDiagState.
"""
if init_sds is None:
init_sds = tree_map(
lambda x: torch.ones_like(x, requires_grad=x.requires_grad), params
)

return EKFDiagState(params, init_sds)


def update(
state: EKFDiagState,
batch: Any,
log_likelihood: Callable[[TensorTree, Any], float],
lr: float,
transition_sd: float = 0.0,
per_sample: bool = False,
inplace: bool = True,
) -> EKFDiagState:
"""Applies an extended Kalman Filter update to the diagonal Normal distribution.
The update is first order, i.e. the likelihood is approximated by a

log p(y | x, p) ≈ log p(y | x, μ) + lr * g(μ)ᵀ(p - μ)
+ lr * 1/2 (p - μ)ᵀ F_d(μ) (p - μ) T⁻¹

where μ is the mean of the variational distribution, lr is the learning rate
(likelihood inverse temperature), whilst g(μ) is the gradient and F_d(μ) the
negative diagonal empirical Fisher of the log-likelihood with respect to the
parameters.

Args:
state: Current state.
batch: Input data to log_likelihood.
log_likelihood: Function that takes parameters and input batch and
returns the log-likelihood.
        lr: Inverse temperature of the update, which behaves like a learning rate.
            See https://arxiv.org/abs/1703.00209 for details.
        transition_sd: Standard deviation of the transition noise, to additively
            inflate the diagonal covariance before the update. Defaults to zero.
        per_sample: If True, log_likelihood is assumed to return a vector of
            log likelihoods, one per sample in the batch. If False, log_likelihood
            is assumed to return a scalar log likelihood for the whole batch; in that
            case torch.func.vmap will be called, which is typically slower than
            writing log_likelihood to be per-sample directly.
inplace: Whether to update the state parameters in-place.

Returns:
Updated EKFDiagState.
"""

if per_sample:
log_likelihood_per_sample = log_likelihood
else:
# per-sample gradients following https://pytorch.org/tutorials/intermediate/per_sample_grads.html
@partial(vmap, in_dims=(None, 0))
def log_likelihood_per_sample(params, batch):
batch = tree_map(lambda x: x.unsqueeze(0), batch)
return log_likelihood(params, batch)

predict_sd_diag = flexi_tree_map(
lambda x: (x**2 + transition_sd**2) ** 0.5, state.sd_diag, inplace=inplace
)
with torch.no_grad():
log_lik = log_likelihood_per_sample(state.mean, batch).mean()
jac = jacrev(log_likelihood_per_sample)(state.mean, batch)
grad = tree_map(lambda x: x.mean(0), jac)
diag_lik_hessian_approx = tree_map(lambda x: -(x**2).mean(0), jac)

update_sd_diag = flexi_tree_map(
lambda sig, h: (sig**-2 - lr * h) ** -0.5,
predict_sd_diag,
diag_lik_hessian_approx,
inplace=inplace,
)
update_mean = flexi_tree_map(
lambda mu, sig, g: mu + sig**2 * lr * g,
state.mean,
update_sd_diag,
grad,
inplace=inplace,
)
return EKFDiagState(update_mean, update_sd_diag, log_lik.item())


def build(
log_likelihood: Callable[[TensorTree, Any], float],
lr: float,
transition_sd: float = 0.0,
per_sample: bool = False,
init_sds: TensorTree | None = None,
) -> Transform:
"""Builds a transform for variational inference with a diagonal Normal
distribution over parameters.

Args:
log_likelihood: Function that takes parameters and input batch and
returns the log-likelihood.
        lr: Inverse temperature of the update, which behaves like a learning rate.
            See https://arxiv.org/abs/1703.00209 for details.
        transition_sd: Standard deviation of the transition noise, to additively
            inflate the diagonal covariance before the update. Defaults to zero.
        per_sample: If True, log_likelihood is assumed to return a vector of
            log likelihoods, one per sample in the batch. If False, log_likelihood
            is assumed to return a scalar log likelihood for the whole batch; in that
            case torch.func.vmap will be called, which is typically slower than
            writing log_likelihood to be per-sample directly.
init_sds: Initial square-root diagonal of the covariance matrix
of the variational distribution. Defaults to ones.

Returns:
Diagonal EKF transform (uqlib.types.Transform instance).
"""
init_fn = partial(init, init_sds=init_sds)
update_fn = partial(
update,
log_likelihood=log_likelihood,
lr=lr,
transition_sd=transition_sd,
per_sample=per_sample,
)
return Transform(init_fn, update_fn)


def sample(state: EKFDiagState, sample_shape: torch.Size = torch.Size([])):
"""Single sample from diagonal Normal distribution over parameters.

Args:
state: State encoding mean and standard deviations.

Returns:
Sample from Normal distribution.
"""
return diag_normal_sample(state.mean, state.sd_diag, sample_shape=sample_shape)
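
For context, a minimal end-to-end sketch of how the new module is intended to be used, adapted from the tests in this PR (the Gaussian target, learning rate and step count are illustrative only):

import torch
from optree import tree_map

from uqlib import ekf
from uqlib.utils import diag_normal_log_prob

# Illustrative target: a diagonal Gaussian over a small parameter tree
target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean)

def log_lik(params, batch):
    # batch is ignored here; a real model would evaluate log p(batch | params)
    return diag_normal_log_prob(params, target_mean, target_sds)

transform = ekf.diag_fisher.build(log_lik, lr=1e-3)
state = transform.init(
    tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)
)

batch = torch.arange(3).reshape(-1, 1)
for _ in range(1000):
    state = transform.update(state, batch)

# state.mean and state.sd_diag encode the running diagonal Normal distribution;
# parameter samples can be drawn with ekf.diag_fisher.sample
samples = ekf.diag_fisher.sample(state, sample_shape=torch.Size([10]))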