Skip to content

Commit

Permalink
Investigate adding acquisition metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Jul 11, 2023
1 parent d92fa07 commit 49f6ae3
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 9 deletions.
46 changes: 40 additions & 6 deletions trieste/acquisition/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable, Generic, Mapping, Optional
from typing import Any, Callable, Generic, Mapping, Optional

from ..data import Dataset
from ..models.interfaces import ProbabilisticModelType
Expand Down Expand Up @@ -57,13 +57,15 @@ def prepare_acquisition_function(
self,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
"""
Prepare an acquisition function. We assume that this requires at least models, but
it may sometimes also need data.
:param models: The models for each tag.
:param datasets: The data from the observer (optional).
:param metadata: Any metadata to pass to the acquisition function (optional).
:return: An acquisition function.
"""

Expand All @@ -72,6 +74,7 @@ def update_acquisition_function(
function: AcquisitionFunction,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
"""
Update an acquisition function. By default this generates a new acquisition function each
Expand All @@ -82,6 +85,7 @@ def update_acquisition_function(
:param function: The acquisition function to update.
:param models: The models for each tag.
:param datasets: The data from the observer (optional).
:param metadata: Any metadata to pass to the acquisition function (optional).
:return: The updated acquisition function.
"""
return self.prepare_acquisition_function(models, datasets=datasets)
Expand Down Expand Up @@ -110,19 +114,26 @@ def prepare_acquisition_function(
self,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
return self.single_builder.prepare_acquisition_function(
models[tag], dataset=None if datasets is None else datasets[tag]
models[tag],
dataset=None if datasets is None else datasets[tag],
metadata=metadata,
)

def update_acquisition_function(
self,
function: AcquisitionFunction,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
return self.single_builder.update_acquisition_function(
function, models[tag], dataset=None if datasets is None else datasets[tag]
function,
models[tag],
dataset=None if datasets is None else datasets[tag],
metadata=metadata,
)

def __repr__(self) -> str:
Expand All @@ -135,10 +146,12 @@ def prepare_acquisition_function(
self,
model: ProbabilisticModelType,
dataset: Optional[Dataset] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
"""
:param model: The model.
:param dataset: The data to use to build the acquisition function (optional).
:param metadata: Any metadata to pass to the acquisition function (optional).
:return: An acquisition function.
"""

Expand All @@ -147,14 +160,16 @@ def update_acquisition_function(
function: AcquisitionFunction,
model: ProbabilisticModelType,
dataset: Optional[Dataset] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
"""
:param function: The acquisition function to update.
:param model: The model.
:param dataset: The data from the observer (optional).
:param metadata: Any metadata to pass to the acquisition function (optional).
:return: The updated acquisition function.
"""
return self.prepare_acquisition_function(model, dataset=dataset)
return self.prepare_acquisition_function(model, dataset=dataset, metadata=metadata)


class GreedyAcquisitionFunctionBuilder(Generic[ProbabilisticModelType], ABC):
Expand All @@ -174,6 +189,7 @@ def prepare_acquisition_function(
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
pending_points: Optional[TensorType] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
"""
Generate a new acquisition function. The first time this is called, ``pending_points``
Expand All @@ -184,6 +200,7 @@ def prepare_acquisition_function(
:param datasets: The data from the observer (optional).
:param pending_points: Points already chosen to be in the current batch (of shape [M,D]),
where M is the number of pending points and D is the search space dimension.
:param metadata: Any metadata to pass to the acquisition function (optional).
:return: An acquisition function.
"""

Expand All @@ -194,6 +211,7 @@ def update_acquisition_function(
datasets: Optional[Mapping[Tag, Dataset]] = None,
pending_points: Optional[TensorType] = None,
new_optimization_step: bool = True,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
"""
Update an acquisition function. By default this generates a new acquisition function each
Expand All @@ -209,6 +227,7 @@ def update_acquisition_function(
:param new_optimization_step: Indicates whether this call to update_acquisition_function
is to start of a new optimization step, of to continue collecting batch of points
for the current step. Defaults to ``True``.
:param metadata: Any metadata to pass to the acquisition function (optional).
:return: The updated acquisition function.
"""
return self.prepare_acquisition_function(
Expand Down Expand Up @@ -240,11 +259,13 @@ def prepare_acquisition_function(
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
pending_points: Optional[TensorType] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
return self.single_builder.prepare_acquisition_function(
models[tag],
dataset=None if datasets is None else datasets[tag],
pending_points=pending_points,
metadata=metadata,
)

def update_acquisition_function(
Expand All @@ -254,13 +275,15 @@ def update_acquisition_function(
datasets: Optional[Mapping[Tag, Dataset]] = None,
pending_points: Optional[TensorType] = None,
new_optimization_step: bool = True,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
return self.single_builder.update_acquisition_function(
function,
models[tag],
dataset=None if datasets is None else datasets[tag],
pending_points=pending_points,
new_optimization_step=new_optimization_step,
metadata=metadata,
)

def __repr__(self) -> str:
Expand All @@ -274,12 +297,14 @@ def prepare_acquisition_function(
model: ProbabilisticModelType,
dataset: Optional[Dataset] = None,
pending_points: Optional[TensorType] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
"""
:param model: The model.
:param dataset: The data from the observer (optional).
:param pending_points: Points already chosen to be in the current batch (of shape [M,D]),
where M is the number of pending points and D is the search space dimension.
:param metadata: Any metadata to pass to the acquisition function (optional).
:return: An acquisition function.
"""

Expand All @@ -290,6 +315,7 @@ def update_acquisition_function(
dataset: Optional[Dataset] = None,
pending_points: Optional[TensorType] = None,
new_optimization_step: bool = True,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
"""
:param function: The acquisition function to update.
Expand All @@ -300,6 +326,7 @@ def update_acquisition_function(
:param new_optimization_step: Indicates whether this call to update_acquisition_function
is to start of a new optimization step, of to continue collecting batch of points
for the current step. Defaults to ``True``.
:param metadata: Any metadata to pass to the acquisition function (optional).
:return: The updated acquisition function.
"""
return self.prepare_acquisition_function(
Expand Down Expand Up @@ -344,19 +371,26 @@ def prepare_acquisition_function(
self,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
return self.single_builder.prepare_acquisition_function(
models[tag], dataset=None if datasets is None else datasets[tag]
models[tag],
dataset=None if datasets is None else datasets[tag],
metadata=metadata,
)

def update_acquisition_function(
self,
function: AcquisitionFunction,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
return self.single_builder.update_acquisition_function(
function, models[tag], dataset=None if datasets is None else datasets[tag]
function,
models[tag],
dataset=None if datasets is None else datasets[tag],
metadata=metadata,
)

def __repr__(self) -> str:
Expand Down
Loading

0 comments on commit 49f6ae3

Please sign in to comment.