Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optimizers to unified API #14

Merged
merged 4 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions tests/test_optim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import torch

import uqlib


def test_optim():
optimizer_cls = torch.optim.SGD
lr = 0.1

def loss_fn(p, b):
return torch.sum(p**2), torch.tensor([])

transform = uqlib.optim.build(loss_fn, optimizer_cls, lr=lr)

params = torch.tensor([1.0], requires_grad=True)
state = transform.init(params)

for _ in range(100):
state = transform.update(state, torch.tensor([1.0]))

assert state.loss < 1e-3
assert state.params < 1e-3
22 changes: 22 additions & 0 deletions tests/test_torchopt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import torch
import torchopt

import uqlib


def test_torchopt():
optimizer = torchopt.sgd(lr=0.1)

def loss_fn(p, b):
return torch.sum(p**2), torch.tensor([])

transform = uqlib.torchopt.build(loss_fn, optimizer)

params = torch.tensor([1.0], requires_grad=True)
state = transform.init(params)

for _ in range(100):
state = transform.update(state, torch.tensor([1.0]))

assert state.loss < 1e-3
assert state.params < 1e-3
2 changes: 2 additions & 0 deletions uqlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from uqlib import sgmcmc
from uqlib import types
from uqlib import vi
from uqlib import optim
from uqlib import torchopt

from uqlib.utils import model_to_function
from uqlib.utils import hvp
Expand Down
103 changes: 103 additions & 0 deletions uqlib/optim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from typing import Type, NamedTuple, Any
from functools import partial
import torch

from uqlib.types import TensorTree, Transform, LogProbFn


class OptimState(NamedTuple):
"""State of an optimizer.

Args:
params: Parameters to be optimised.
optimizer: torch.optim optimizer instance.
loss: Loss value.
aux: Auxiliary information from the loss function call.
"""

params: TensorTree
optimizer: torch.optim.Optimizer
loss: torch.tensor = torch.tensor(0.0)
aux: Any = None


def init(
params: TensorTree,
optimizer_cls: Type[torch.optim.Optimizer],
*args: Any,
**kwargs: Any,
) -> OptimState:
"""Initialise an optimizer.

Args:
params: Parameters to be optimised.
optimizer_cls: Optimizer class from torch.optim.
*args: Positional arguments to pass to the optimizer class.
**kwargs: Keyword arguments to pass to the optimizer class.

Returns:
Initial OptimState.
"""
opt_params = [params] if isinstance(params, torch.Tensor) else params

optimizer = optimizer_cls(opt_params, *args, **kwargs)
return OptimState(params, optimizer)


def update(
state: OptimState,
batch: TensorTree,
loss_fn: LogProbFn,
inplace: bool = True,
) -> OptimState:
"""Perform a single update step of the optimizer.

Args:
state: Current optimizer state.
batch: Input data to loss_fn.
loss_fn: Function that takes the parameters and returns the loss.
of the form `loss, aux = fn(params, batch)`.
inplace: Whether to update the parameters in place.
inplace=False not supported for uqlib.optim

Returns:
Updated OptimState.
"""
if not inplace:
raise NotImplementedError("inplace=False not supported for uqlib.optim")
state.optimizer.zero_grad()
loss, aux = loss_fn(state.params, batch)
loss.backward()
state.optimizer.step()
return OptimState(state.params, state.optimizer, state.loss.detach(), aux)


def build(
loss_fn: LogProbFn,
optimizer: Type[torch.optim.Optimizer],
**kwargs: Any,
) -> Transform:
"""Builds an optimizer transform from torch.optim.

Example usage:

```
transform = build(loss_fn, torch.optim.Adam, lr=0.1)
state = transform.init(params)

for batch in dataloader:
state = transform.update(state, batch)
```

Arg:
loss_fn: Function that takes the parameters and returns the loss.
of the form `loss, aux = fn(params, batch)`.
optimizer: Optimizer class from torch.optim.
**kwargs: Keyword arguments to pass to the optimizer class.

Returns:
Optimizer transform (uqlib.types.Transform instance).
"""
init_fn = partial(init, optimizer_cls=optimizer, **kwargs)
update_fn = partial(update, loss_fn=loss_fn)
return Transform(init_fn, update_fn)
100 changes: 100 additions & 0 deletions uqlib/torchopt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from typing import NamedTuple, Any
from functools import partial
import torch
import torchopt

from uqlib.types import TensorTree, Transform, LogProbFn


class TorchOptState(NamedTuple):
"""State of a TorchOpt optimizer.

Args:
params: Parameters to be optimised.
opt_state: TorchOpt optimizer state.
loss: Loss value.
aux: Auxiliary information from the loss function call.
"""

params: TensorTree
opt_state: torch.optim.Optimizer
loss: torch.tensor = torch.tensor(0.0)
aux: Any = None


def init(
params: TensorTree,
optimizer: torchopt.base.GradientTransformation,
) -> TorchOptState:
"""Initialise a TorchOpt optimizer.

Args:
params: Parameters to be optimised.
optimizer: TorchOpt functional optimizer.
Make sure to use lower case like torchopt.adam()

Returns:
Initial TorchOptState.
"""
opt_state = optimizer.init(params)
return TorchOptState(params, opt_state)


def update(
state: TorchOptState,
batch: TensorTree,
loss_fn: LogProbFn,
optimizer: torchopt.base.GradientTransformation,
inplace: bool = True,
) -> TorchOptState:
"""Update the optimizer state.

Args:
state: Current state.
batch: Batch of data.
loss_fn: Loss function.
optimizer: TorchOpt functional optimizer.
Make sure to use lower case like torchopt.adam()
inplace: Whether to update the state in place.

Returns:
Updated state.
"""
params = state.params
opt_state = state.opt_state
with torch.no_grad():
grads, (loss, aux) = torch.func.grad_and_value(loss_fn, has_aux=True)(
params, batch
)
updates, opt_state = optimizer.update(grads, opt_state)
params = torchopt.apply_updates(params, updates, inplace=inplace)
return TorchOptState(params, opt_state, loss, aux)


def build(
loss_fn: LogProbFn,
optimizer: torchopt.base.GradientTransformation,
) -> Transform:
"""Build a TorchOpt optimizer transformation.

Example usage:

```
transform = build(loss_fn, torchopt.adam(lr=0.1))
state = transform.init(params)

for batch in dataloader:
state = transform.update(state, batch)
```

Args:
loss_fn: Loss function.
optimizer: TorchOpt functional optimizer.
Make sure to use lower case like torchopt.adam()

Returns:
Torchopt optimizer transform (uqlib.types.Transform instance).
"""
init_fn = partial(init, optimizer=optimizer)
update_fn = partial(update, optimizer=optimizer, loss_fn=loss_fn)
return Transform(init_fn, update_fn)
Loading