diff --git a/uqlib/__init__.py b/uqlib/__init__.py
index d995314e..a9edac76 100644
--- a/uqlib/__init__.py
+++ b/uqlib/__init__.py
@@ -1,6 +1,7 @@
 from uqlib import laplace
 from uqlib import vi
 from uqlib import sgmcmc
+from uqlib import types
 
 from uqlib.utils import model_to_function
 from uqlib.utils import hvp
diff --git a/uqlib/laplace/diag_fisher.py b/uqlib/laplace/diag_fisher.py
index f5bc7eb0..9f0c6557 100644
--- a/uqlib/laplace/diag_fisher.py
+++ b/uqlib/laplace/diag_fisher.py
@@ -4,6 +4,7 @@
 from torch.func import jacrev, vmap
 from optree import tree_map
 
+from uqlib.types import TensorTree
 from uqlib.utils import diag_normal_sample
 
 
@@ -15,13 +16,13 @@ class DiagLaplaceState(NamedTuple):
         prec_diag: Diagonal of the precision matrix of the Normal distribution.
     """
 
-    mean: Any
-    prec_diag: Any
+    mean: TensorTree
+    prec_diag: TensorTree
 
 
 def init(
-    mean: Any,
-    init_prec_diag: Any = None,
+    mean: TensorTree,
+    init_prec_diag: TensorTree | None = None,
 ) -> DiagLaplaceState:
     """Initialise diagonal Normal distribution over parameters.
 
@@ -40,7 +41,7 @@ def init(
 
 def update(
     state: DiagLaplaceState,
-    log_posterior: Callable[[Any, Any], float],
+    log_posterior: Callable[[TensorTree, Any], float],
     batch: Any,
     per_sample: bool = False,
 ) -> DiagLaplaceState:
@@ -88,14 +89,16 @@ def log_posterior_per_sample(params, batch):
     return DiagLaplaceState(state.mean, prec_diag)
 
 
-def sample(state: DiagLaplaceState, sample_shape: torch.Size = torch.Size([])):
-    """Single sample from diagonal Normal distribution over parameters.
+def sample(
+    state: DiagLaplaceState, sample_shape: torch.Size = torch.Size([])
+) -> TensorTree:
+    """Sample from diagonal Normal distribution over parameters.
 
     Args:
         state: State encoding mean and diagonal precision.
 
     Returns:
-        Sample from Normal distribution.
+        Sample(s) from Normal distribution.
     """
     sd_diag = tree_map(lambda x: x.sqrt().reciprocal(), state.prec_diag)
     return diag_normal_sample(state.mean, sd_diag, sample_shape=sample_shape)
diff --git a/uqlib/laplace/diag_hessian.py b/uqlib/laplace/diag_hessian.py
index 10e1a83f..030452fb 100644
--- a/uqlib/laplace/diag_hessian.py
+++ b/uqlib/laplace/diag_hessian.py
@@ -2,13 +2,14 @@
 import torch
 from optree import tree_map, tree_flatten
 
+from uqlib.types import TensorTree
 from uqlib.utils import hessian_diag, diag_normal_sample
 from uqlib.laplace.diag_fisher import DiagLaplaceState
 
 
 def init(
-    init_mean: Any,
-    init_prec_diag: Any = None,
+    init_mean: TensorTree,
+    init_prec_diag: TensorTree | None = None,
 ) -> DiagLaplaceState:
     """Initialise diagonal Normal distribution over parameters.
 
@@ -26,7 +27,7 @@ def init(
 
 def update(
     state: DiagLaplaceState,
-    log_posterior: Callable[[Any, Any], float],
+    log_posterior: Callable[[TensorTree, Any], float],
     batch: Any,
 ) -> DiagLaplaceState:
     """Adds diagonal negative Hessian summed across given batch.
@@ -60,14 +61,16 @@ def update(
     return DiagLaplaceState(state.mean, batch_prec_diag)
 
 
-def sample(state: DiagLaplaceState, sample_shape: torch.Size = torch.Size([])):
-    """Single sample from diagonal Normal distribution over parameters.
+def sample(
+    state: DiagLaplaceState, sample_shape: torch.Size = torch.Size([])
+) -> TensorTree:
+    """Sample from diagonal Normal distribution over parameters.
 
     Args:
         state: State encoding mean and diagonal precision.
 
     Returns:
-        Sample from Normal distribution.
+        Sample(s) from Normal distribution.
     """
     sd_diag = tree_map(lambda x: x.sqrt().reciprocal(), state.prec_diag)
     return diag_normal_sample(state.mean, sd_diag, sample_shape=sample_shape)
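Note: with the widened `sample` signature above, one Laplace state can now produce a whole batch of draws. A minimal sketch of the intended flow, using the module-level init/update/sample API defined in these files (the toy `params`, `log_posterior` and `batch` below are illustrative, not part of the diff):

    import torch
    from uqlib.laplace import diag_hessian

    params = {"weight": torch.randn(2, 2), "bias": torch.zeros(2)}

    def log_posterior(p, batch):
        # illustrative unnormalised log posterior: negative squared error over the batch
        x, y = batch
        return -((x @ p["weight"] + p["bias"] - y) ** 2).sum()

    batch = (torch.randn(8, 2), torch.randn(8, 2))
    state = diag_hessian.init(params)
    state = diag_hessian.update(state, log_posterior, batch)
    samples = diag_hessian.sample(state, sample_shape=torch.Size([10]))  # 10 draws per leaf

The draws inherit the TensorTree structure of `state.mean`, so downstream code can tree_map over them directly.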
""" sd_diag = tree_map(lambda x: x.sqrt().reciprocal(), state.prec_diag) return diag_normal_sample(state.mean, sd_diag, sample_shape=sample_shape) diff --git a/uqlib/sgmcmc/optim/SGHMC.py b/uqlib/sgmcmc/optim/SGHMC.py index ec4b31a6..2f7b0b2d 100644 --- a/uqlib/sgmcmc/optim/SGHMC.py +++ b/uqlib/sgmcmc/optim/SGHMC.py @@ -1,7 +1,8 @@ from collections import defaultdict -from typing import Any, DefaultDict +from typing import DefaultDict from torchopt.optim.base import Optimizer +from uqlib.types import TensorTree from uqlib.sgmcmc import sghmc @@ -17,13 +18,13 @@ class SGHMC(Optimizer): def __init__( self, - params: Any, + params: TensorTree, lr: float, alpha: float = 0.01, beta: float = 0.0, temperature: float = 1.0, maximize: bool = True, - momenta: Any | None = None, + momenta: TensorTree | None = None, ) -> None: """Initialise SGHMC. diff --git a/uqlib/sgmcmc/sghmc.py b/uqlib/sgmcmc/sghmc.py index 676782e9..dabdc66c 100644 --- a/uqlib/sgmcmc/sghmc.py +++ b/uqlib/sgmcmc/sghmc.py @@ -1,9 +1,11 @@ -from typing import Any, NamedTuple, Tuple +from typing import NamedTuple, Tuple from functools import partial import torch from torchopt.base import GradientTransformation from optree import tree_map, tree_map_ +from uqlib.types import TensorTree + class SGHMCState(NamedTuple): """State enconding momenta for SGHMC. @@ -12,10 +14,10 @@ class SGHMCState(NamedTuple): momenta: Momenta for each parameter. """ - momenta: Any + momenta: TensorTree -def init(params: Any, momenta: Any | None = None) -> SGHMCState: +def init(params: TensorTree, momenta: TensorTree | None = None) -> SGHMCState: """Initialise momenta for SGHMC. Args: @@ -31,16 +33,16 @@ def init(params: Any, momenta: Any | None = None) -> SGHMCState: def update( - updates: Any, + updates: TensorTree, state: SGHMCState, lr: float, alpha: float = 0.01, beta: float = 0.0, temperature: float = 1.0, maximize: bool = True, - params: Any | None = None, + params: TensorTree | None = None, inplace: bool = True, -) -> Tuple[Any, SGHMCState]: +) -> Tuple[TensorTree, SGHMCState]: """Updates gradients and momenta for SGHMC. Args: @@ -91,7 +93,7 @@ def build( beta: float = 0.0, temperature: float = 1.0, maximize: bool = True, - momenta: Any | None = None, + momenta: TensorTree | None = None, ) -> GradientTransformation: """Builds SGHMC optimizer. diff --git a/uqlib/types.py b/uqlib/types.py new file mode 100644 index 00000000..a192a2bc --- /dev/null +++ b/uqlib/types.py @@ -0,0 +1,5 @@ +from typing import TypeAlias +from optree.typing import PyTreeTypeVar +from torch import Tensor + +TensorTree: TypeAlias = PyTreeTypeVar("TensorTree", Tensor) diff --git a/uqlib/utils.py b/uqlib/utils.py index c5e9537c..d5c76b98 100644 --- a/uqlib/utils.py +++ b/uqlib/utils.py @@ -5,15 +5,17 @@ from torch.distributions import Normal from optree import tree_map, tree_map_, tree_reduce +from uqlib.types import TensorTree -def model_to_function(model: torch.nn.Module) -> Callable[[dict, Any], Any]: + +def model_to_function(model: torch.nn.Module) -> Callable[[TensorTree, Any], Any]: """Converts a model into a function that maps parameters and inputs to outputs. Args: model: torch.nn.Module with parameters stored in .named_parameters(). Returns: - Function that takes a dictionary of parameters as well as any input + Function that takes a PyTree of parameters as well as any input arg or kwargs and returns the output of the model. 
""" @@ -43,7 +45,7 @@ def hvp(f: Callable, primals: tuple, tangents: tuple): def hessian_diag(f: Callable) -> Callable: - """Modify a scalar-valued function that takes a dict (with tensor values) as first + """Modify a scalar-valued function that takes a PyTree (with tensor values) as first input to return its Hessian diagonal. Inspired by https://github.com/google/jax/issues/3801 @@ -68,7 +70,7 @@ def ftemp(xtemp): def diag_normal_log_prob( - x: Any, mean: Any, sd_diag: Any, validate_args: bool = False + x: TensorTree, mean: TensorTree, sd_diag: TensorTree, validate_args: bool = False ) -> float: """Evaluate multivariate normal log probability for a diagonal covariance matrix. @@ -93,16 +95,17 @@ def diag_normal_log_prob( def diag_normal_sample( - mean: Any, sd_diag: Any, sample_shape: torch.Size = torch.Size([]) + mean: TensorTree, sd_diag: TensorTree, sample_shape: torch.Size = torch.Size([]) ) -> dict: - """Single sample from multivariate normal with diagonal covariance matrix. + """Sample from multivariate normal with diagonal covariance matrix. Args: mean: Mean of the distribution. sd_diag: Square-root diagonal of the covariance matrix. + sample_shape: Shape of the sample. Returns: - Sample from normal distribution with the same structure as mean and sd_diag. + Sample(s) from normal distribution with the same structure as mean and sd_diag. """ return tree_map( lambda m, sd: m + torch.randn(sample_shape + m.shape, device=m.device) * sd, @@ -111,7 +114,7 @@ def diag_normal_sample( ) -def tree_extract(f: Callable[[torch.tensor], bool], tree: Any) -> Any: +def tree_extract(f: Callable[[torch.tensor], bool], tree: TensorTree) -> TensorTree: """Extracts values from a PyTree where f returns True. False values are replaced with empty tensors. @@ -126,8 +129,8 @@ def tree_extract(f: Callable[[torch.tensor], bool], tree: Any) -> Any: def tree_insert( - f: Callable[[torch.tensor], bool], full_tree: Any, sub_tree: Any -) -> Any: + f: Callable[[torch.tensor], bool], full_tree: TensorTree, sub_tree: TensorTree +) -> TensorTree: """Inserts sub_tree into full_tree where full_tree tensors evaluate f to True. Both PyTrees must have the same structure. @@ -147,8 +150,8 @@ def tree_insert( def tree_insert_( - f: Callable[[torch.tensor], bool], full_tree: Any, sub_tree: Any -) -> Any: + f: Callable[[torch.tensor], bool], full_tree: TensorTree, sub_tree: TensorTree +) -> TensorTree: """Inserts sub_tree into full_tree in-place where full_tree tensors evaluate f to True. Both PyTrees must have the same structure. @@ -168,7 +171,7 @@ def insert_(full, sub): return tree_map_(insert_, full_tree, sub_tree) -def extract_requires_grad(tree: Any) -> Any: +def extract_requires_grad(tree: TensorTree) -> TensorTree: """Extracts only parameters that require gradients. Args: @@ -180,7 +183,7 @@ def extract_requires_grad(tree: Any) -> Any: return tree_extract(lambda x: x.requires_grad, tree) -def insert_requires_grad(full_tree: Any, sub_tree: Any) -> Any: +def insert_requires_grad(full_tree: TensorTree, sub_tree: TensorTree) -> TensorTree: """Inserts sub_tree into full_tree where full_tree tensors requires_grad. Both PyTrees must have the same structure. 
@@ -194,7 +197,7 @@ def insert_requires_grad(full_tree: Any, sub_tree: Any) -> Any:
     return tree_insert(lambda x: x.requires_grad, full_tree, sub_tree)
 
 
-def insert_requires_grad_(full_tree: Any, sub_tree: Any) -> Any:
+def insert_requires_grad_(full_tree: TensorTree, sub_tree: TensorTree) -> TensorTree:
     """Inserts sub_tree into full_tree in-place where full_tree tensors requires_grad.
 
     Both PyTrees must have the same structure.
@@ -209,8 +212,8 @@ def insert_requires_grad_(full_tree: Any, sub_tree: Any) -> Any:
 
 
 def extract_requires_grad_and_func(
-    tree: Any, func: Callable, inplace: bool = False
-) -> Tuple[Any, Callable]:
+    tree: TensorTree, func: Callable, inplace: bool = False
+) -> Tuple[TensorTree, Callable]:
     """Extracts only parameters that require gradients and converts a function
     that takes the full parameter tree (in its first argument)
     into one that takes the subtree.
diff --git a/uqlib/vi/diag.py b/uqlib/vi/diag.py
index 7b710670..cc2b366e 100644
--- a/uqlib/vi/diag.py
+++ b/uqlib/vi/diag.py
@@ -4,6 +4,7 @@
 from optree import tree_map
 import torchopt
 
+from uqlib.types import TensorTree
 from uqlib.utils import diag_normal_log_prob, diag_normal_sample
 
 
@@ -19,16 +20,16 @@ class VIDiagState(NamedTuple):
         nelbo: Negative evidence lower bound (lower is better).
     """
 
-    mean: Any
-    log_sd_diag: Any
+    mean: TensorTree
+    log_sd_diag: TensorTree
     optimizer_state: tuple
     nelbo: float = 0
 
 
 def init(
-    init_mean: Any,
+    init_mean: TensorTree,
     optimizer: torchopt.base.GradientTransformation,
-    init_log_sds: Any = None,
+    init_log_sds: TensorTree | None = None,
 ) -> VIDiagState:
     """Initialise diagonal Normal variational distribution over parameters.
 
@@ -65,7 +66,7 @@ def init(
 
 def update(
     state: VIDiagState,
-    log_posterior: Callable[[Any, Any], float],
+    log_posterior: Callable[[TensorTree, Any], float],
     batch: Any,
     optimizer: torchopt.base.GradientTransformation,
     temperature: float = 1.0,
@@ -119,7 +120,7 @@ def update(
 def nelbo(
     mean: dict,
     sd_diag: dict,
-    log_posterior: Callable[[Any, Any], float],
+    log_posterior: Callable[[TensorTree, Any], float],
     batch: Any,
     temperature: float = 1.0,
     n_samples: int = 1,
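Note: to close the loop on the vi.diag annotations, an end-to-end sketch of the init/update cycle (the torchopt optimizer choice, toy `log_posterior` and synthetic batches are illustrative assumptions, not prescribed by the diff):

    import torch
    import torchopt
    from uqlib.vi import diag

    def log_posterior(params, batch):
        # illustrative unnormalised log posterior
        return -0.5 * ((params["w"] - batch) ** 2).sum()

    optimizer = torchopt.adam(lr=1e-2)
    state = diag.init({"w": torch.zeros(3)}, optimizer)
    for _ in range(100):
        state = diag.update(state, log_posterior, torch.randn(3), optimizer)

    print(state.nelbo)  # negative ELBO of the last step (lower is better)

Because `log_posterior` is typed as `Callable[[TensorTree, Any], float]` throughout, the same callable can be reused across the laplace, sgmcmc and vi transforms without modification.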