Skip to content

Commit

Permalink
Merge pull request #7 from normal-computing/types
Browse files Browse the repository at this point in the history
Add TensorTree type
  • Loading branch information
SamDuffield authored Jan 22, 2024
2 parents d156cc3 + e1e09b8 commit ac07f37
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 47 deletions.
1 change: 1 addition & 0 deletions uqlib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from uqlib import laplace
from uqlib import vi
from uqlib import sgmcmc
from uqlib import types

from uqlib.utils import model_to_function
from uqlib.utils import hvp
Expand Down
19 changes: 11 additions & 8 deletions uqlib/laplace/diag_fisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch.func import jacrev, vmap
from optree import tree_map

from uqlib.types import TensorTree
from uqlib.utils import diag_normal_sample


Expand All @@ -15,13 +16,13 @@ class DiagLaplaceState(NamedTuple):
prec_diag: Diagonal of the precision matrix of the Normal distribution.
"""

mean: Any
prec_diag: Any
mean: TensorTree
prec_diag: TensorTree


def init(
mean: Any,
init_prec_diag: Any = None,
mean: TensorTree,
init_prec_diag: TensorTree = None,
) -> DiagLaplaceState:
"""Initialise diagonal Normal distribution over parameters.
Expand All @@ -40,7 +41,7 @@ def init(

def update(
state: DiagLaplaceState,
log_posterior: Callable[[Any, Any], float],
log_posterior: Callable[[TensorTree, Any], float],
batch: Any,
per_sample: bool = False,
) -> DiagLaplaceState:
Expand Down Expand Up @@ -88,14 +89,16 @@ def log_posterior_per_sample(params, batch):
return DiagLaplaceState(state.mean, prec_diag)


def sample(state: DiagLaplaceState, sample_shape: torch.Size = torch.Size([])):
"""Single sample from diagonal Normal distribution over parameters.
def sample(
state: DiagLaplaceState, sample_shape: torch.Size = torch.Size([])
) -> TensorTree:
"""Sample from diagonal Normal distribution over parameters.
Args:
state: State encoding mean and diagonal precision.
Returns:
Sample from Normal distribution.
Sample(s) from Normal distribution.
"""
sd_diag = tree_map(lambda x: x.sqrt().reciprocal(), state.prec_diag)
return diag_normal_sample(state.mean, sd_diag, sample_shape=sample_shape)
15 changes: 9 additions & 6 deletions uqlib/laplace/diag_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import torch
from optree import tree_map, tree_flatten

from uqlib.types import TensorTree
from uqlib.utils import hessian_diag, diag_normal_sample
from uqlib.laplace.diag_fisher import DiagLaplaceState


def init(
init_mean: Any,
init_prec_diag: Any = None,
init_mean: TensorTree,
init_prec_diag: TensorTree = None,
) -> DiagLaplaceState:
"""Initialise diagonal Normal distribution over parameters.
Expand All @@ -26,7 +27,7 @@ def init(

def update(
state: DiagLaplaceState,
log_posterior: Callable[[Any, Any], float],
log_posterior: Callable[[TensorTree, Any], float],
batch: Any,
) -> DiagLaplaceState:
"""Adds diagonal negative Hessian summed across given batch.
Expand Down Expand Up @@ -60,14 +61,16 @@ def update(
return DiagLaplaceState(state.mean, batch_prec_diag)


def sample(state: DiagLaplaceState, sample_shape: torch.Size = torch.Size([])):
"""Single sample from diagonal Normal distribution over parameters.
def sample(
state: DiagLaplaceState, sample_shape: torch.Size = torch.Size([])
) -> TensorTree:
"""Sample from diagonal Normal distribution over parameters.
Args:
state: State encoding mean and diagonal precision.
Returns:
Sample from Normal distribution.
Sample(s) from Normal distribution.
"""
sd_diag = tree_map(lambda x: x.sqrt().reciprocal(), state.prec_diag)
return diag_normal_sample(state.mean, sd_diag, sample_shape=sample_shape)
7 changes: 4 additions & 3 deletions uqlib/sgmcmc/optim/SGHMC.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from collections import defaultdict
from typing import Any, DefaultDict
from typing import DefaultDict
from torchopt.optim.base import Optimizer

from uqlib.types import TensorTree
from uqlib.sgmcmc import sghmc


Expand All @@ -17,13 +18,13 @@ class SGHMC(Optimizer):

def __init__(
self,
params: Any,
params: TensorTree,
lr: float,
alpha: float = 0.01,
beta: float = 0.0,
temperature: float = 1.0,
maximize: bool = True,
momenta: Any | None = None,
momenta: TensorTree | None = None,
) -> None:
"""Initialise SGHMC.
Expand Down
16 changes: 9 additions & 7 deletions uqlib/sgmcmc/sghmc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Any, NamedTuple, Tuple
from typing import NamedTuple, Tuple
from functools import partial
import torch
from torchopt.base import GradientTransformation
from optree import tree_map, tree_map_

from uqlib.types import TensorTree


class SGHMCState(NamedTuple):
"""State encoding momenta for SGHMC.
Expand All @@ -12,10 +14,10 @@ class SGHMCState(NamedTuple):
momenta: Momenta for each parameter.
"""

momenta: Any
momenta: TensorTree


def init(params: Any, momenta: Any | None = None) -> SGHMCState:
def init(params: TensorTree, momenta: TensorTree | None = None) -> SGHMCState:
"""Initialise momenta for SGHMC.
Args:
Expand All @@ -31,16 +33,16 @@ def init(params: Any, momenta: Any | None = None) -> SGHMCState:


def update(
updates: Any,
updates: TensorTree,
state: SGHMCState,
lr: float,
alpha: float = 0.01,
beta: float = 0.0,
temperature: float = 1.0,
maximize: bool = True,
params: Any | None = None,
params: TensorTree | None = None,
inplace: bool = True,
) -> Tuple[Any, SGHMCState]:
) -> Tuple[TensorTree, SGHMCState]:
"""Updates gradients and momenta for SGHMC.
Args:
Expand Down Expand Up @@ -91,7 +93,7 @@ def build(
beta: float = 0.0,
temperature: float = 1.0,
maximize: bool = True,
momenta: Any | None = None,
momenta: TensorTree | None = None,
) -> GradientTransformation:
"""Builds SGHMC optimizer.
Expand Down
5 changes: 5 additions & 0 deletions uqlib/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from typing import TypeAlias
from optree.typing import PyTreeTypeVar
from torch import Tensor

# Shared annotation for PyTrees whose leaves are torch Tensors (e.g. model
# parameters, momenta, precision diagonals); imported by the other uqlib modules.
TensorTree: TypeAlias = PyTreeTypeVar("TensorTree", Tensor)
37 changes: 20 additions & 17 deletions uqlib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
from torch.distributions import Normal
from optree import tree_map, tree_map_, tree_reduce

from uqlib.types import TensorTree

def model_to_function(model: torch.nn.Module) -> Callable[[dict, Any], Any]:

def model_to_function(model: torch.nn.Module) -> Callable[[TensorTree, Any], Any]:
"""Converts a model into a function that maps parameters and inputs to outputs.
Args:
model: torch.nn.Module with parameters stored in .named_parameters().
Returns:
Function that takes a dictionary of parameters as well as any input
Function that takes a PyTree of parameters as well as any input
args or kwargs and returns the output of the model.
"""

Expand Down Expand Up @@ -43,7 +45,7 @@ def hvp(f: Callable, primals: tuple, tangents: tuple):


def hessian_diag(f: Callable) -> Callable:
"""Modify a scalar-valued function that takes a dict (with tensor values) as first
"""Modify a scalar-valued function that takes a PyTree (with tensor values) as first
input to return its Hessian diagonal.
Inspired by https://github.com/google/jax/issues/3801
Expand All @@ -68,7 +70,7 @@ def ftemp(xtemp):


def diag_normal_log_prob(
x: Any, mean: Any, sd_diag: Any, validate_args: bool = False
x: TensorTree, mean: TensorTree, sd_diag: TensorTree, validate_args: bool = False
) -> float:
"""Evaluate multivariate normal log probability for a diagonal covariance matrix.
Expand All @@ -93,16 +95,17 @@ def diag_normal_log_prob(


def diag_normal_sample(
mean: Any, sd_diag: Any, sample_shape: torch.Size = torch.Size([])
mean: TensorTree, sd_diag: TensorTree, sample_shape: torch.Size = torch.Size([])
) -> dict:
"""Single sample from multivariate normal with diagonal covariance matrix.
"""Sample from multivariate normal with diagonal covariance matrix.
Args:
mean: Mean of the distribution.
sd_diag: Square-root diagonal of the covariance matrix.
sample_shape: Shape of the sample.
Returns:
Sample from normal distribution with the same structure as mean and sd_diag.
Sample(s) from normal distribution with the same structure as mean and sd_diag.
"""
return tree_map(
lambda m, sd: m + torch.randn(sample_shape + m.shape, device=m.device) * sd,
Expand All @@ -111,7 +114,7 @@ def diag_normal_sample(
)


def tree_extract(f: Callable[[torch.tensor], bool], tree: Any) -> Any:
def tree_extract(f: Callable[[torch.tensor], bool], tree: TensorTree) -> TensorTree:
"""Extracts values from a PyTree where f returns True.
False values are replaced with empty tensors.
Expand All @@ -126,8 +129,8 @@ def tree_extract(f: Callable[[torch.tensor], bool], tree: Any) -> Any:


def tree_insert(
f: Callable[[torch.tensor], bool], full_tree: Any, sub_tree: Any
) -> Any:
f: Callable[[torch.tensor], bool], full_tree: TensorTree, sub_tree: TensorTree
) -> TensorTree:
"""Inserts sub_tree into full_tree where full_tree tensors evaluate f to True.
Both PyTrees must have the same structure.
Expand All @@ -147,8 +150,8 @@ def tree_insert(


def tree_insert_(
f: Callable[[torch.tensor], bool], full_tree: Any, sub_tree: Any
) -> Any:
f: Callable[[torch.tensor], bool], full_tree: TensorTree, sub_tree: TensorTree
) -> TensorTree:
"""Inserts sub_tree into full_tree in-place where full_tree tensors evaluate
f to True. Both PyTrees must have the same structure.
Expand All @@ -168,7 +171,7 @@ def insert_(full, sub):
return tree_map_(insert_, full_tree, sub_tree)


def extract_requires_grad(tree: Any) -> Any:
def extract_requires_grad(tree: TensorTree) -> TensorTree:
"""Extracts only parameters that require gradients.
Args:
Expand All @@ -180,7 +183,7 @@ def extract_requires_grad(tree: Any) -> Any:
return tree_extract(lambda x: x.requires_grad, tree)


def insert_requires_grad(full_tree: Any, sub_tree: Any) -> Any:
def insert_requires_grad(full_tree: TensorTree, sub_tree: TensorTree) -> TensorTree:
"""Inserts sub_tree into full_tree where full_tree tensors requires_grad.
Both PyTrees must have the same structure.
Expand All @@ -194,7 +197,7 @@ def insert_requires_grad(full_tree: Any, sub_tree: Any) -> Any:
return tree_insert(lambda x: x.requires_grad, full_tree, sub_tree)


def insert_requires_grad_(full_tree: Any, sub_tree: Any) -> Any:
def insert_requires_grad_(full_tree: TensorTree, sub_tree: TensorTree) -> TensorTree:
"""Inserts sub_tree into full_tree in-place where full_tree tensors requires_grad.
Both PyTrees must have the same structure.
Expand All @@ -209,8 +212,8 @@ def insert_requires_grad_(full_tree: Any, sub_tree: Any) -> Any:


def extract_requires_grad_and_func(
tree: Any, func: Callable, inplace: bool = False
) -> Tuple[Any, Callable]:
tree: TensorTree, func: Callable, inplace: bool = False
) -> Tuple[TensorTree, Callable]:
"""Extracts only parameters that require gradients and converts a function
that takes the full parameter tree (in its first argument)
into one that takes the subtree.
Expand Down
13 changes: 7 additions & 6 deletions uqlib/vi/diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from optree import tree_map
import torchopt

from uqlib.types import TensorTree
from uqlib.utils import diag_normal_log_prob, diag_normal_sample


Expand All @@ -19,16 +20,16 @@ class VIDiagState(NamedTuple):
nelbo: Negative evidence lower bound (lower is better).
"""

mean: Any
log_sd_diag: Any
mean: TensorTree
log_sd_diag: TensorTree
optimizer_state: tuple
nelbo: float = 0


def init(
init_mean: Any,
init_mean: TensorTree,
optimizer: torchopt.base.GradientTransformation,
init_log_sds: Any = None,
init_log_sds: TensorTree | None = None,
) -> VIDiagState:
"""Initialise diagonal Normal variational distribution over parameters.
Expand Down Expand Up @@ -65,7 +66,7 @@ def init(

def update(
state: VIDiagState,
log_posterior: Callable[[Any, Any], float],
log_posterior: Callable[[TensorTree, Any], float],
batch: Any,
optimizer: torchopt.base.GradientTransformation,
temperature: float = 1.0,
Expand Down Expand Up @@ -119,7 +120,7 @@ def update(
def nelbo(
mean: dict,
sd_diag: dict,
log_posterior: Callable[[Any, Any], float],
log_posterior: Callable[[TensorTree, Any], float],
batch: Any,
temperature: float = 1.0,
n_samples: int = 1,
Expand Down

0 comments on commit ac07f37

Please sign in to comment.