Skip to content

Commit

Permalink
Categorical Trust Regions
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Jul 31, 2024
1 parent e6a0692 commit 003d8a7
Showing 1 changed file with 51 additions and 65 deletions.
116 changes: 51 additions & 65 deletions trieste/acquisition/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@
from ..space import (
Box,
DiscreteSearchSpace,
DiscreteSearchSpaceType,
GeneralDiscreteSearchSpace,
HasOneHotEncoder,
SearchSpace,
TaggedMultiSearchSpace,
TaggedProductSearchSpace,
Expand Down Expand Up @@ -1625,7 +1626,13 @@ def __init__(
self._y_min = tf.constant(np.inf, dtype=self.location.dtype)

def _init_eps(self) -> None:
self.eps = self._zeta * (self.global_search_space.upper - self.global_search_space.lower)
if isinstance(self.global_search_space, HasOneHotEncoder):
# categorical space distance is hardcoded to a (Hamming) distance of 1
self.eps = tf.constant(1)
else:
self.eps = self._zeta * (
self.global_search_space.upper - self.global_search_space.lower
)

@abstractmethod
def _update_domain(self) -> None:
Expand Down Expand Up @@ -1692,9 +1699,12 @@ def update(
datasets = self.select_in_region(datasets) # See `select_in_region` comment above.
x_min, y_min = self.get_dataset_min(datasets)

tr_volume = tf.reduce_prod(self.upper - self.lower)
self._step_is_success = y_min < self._y_min - self._kappa * tr_volume
self.eps = self.eps / self._beta if self._step_is_success else self.eps * self._beta
if isinstance(self.global_search_space, HasOneHotEncoder):
self._step_is_success = y_min < self._y_min
else:
tr_volume = tf.reduce_prod(self.upper - self.lower)
self._step_is_success = y_min < self._y_min - self._kappa * tr_volume
self.eps = self.eps / self._beta if self._step_is_success else self.eps * self._beta

# Only update the location if the step was successful.
if self._step_is_success:
Expand Down Expand Up @@ -2219,16 +2229,14 @@ def get_dataset_min(
return tf.squeeze(x_min, axis=0), tf.squeeze(y_min)


class UpdatableTrustRegionDiscreteMixin(UpdatableTrustRegion, Generic[DiscreteSearchSpaceType]):
class UpdatableTrustRegionDiscrete(DiscreteSearchSpace, UpdatableTrustRegion):
"""
A mixin for updatable discrete search spaces with an associated global search space.
Specific subclasses define distance metrics appropriate for specific types of discrete
spaces.
An updatable discrete search space with an associated global search space.
"""

def __init__(
self,
global_search_space: DiscreteSearchSpaceType,
global_search_space: GeneralDiscreteSearchSpace,
region_index: Optional[int] = None,
input_active_dims: Optional[Union[slice, Sequence[int]]] = None,
):
Expand All @@ -2240,7 +2248,8 @@ def __init__(
:param input_active_dims: The active dimensions of the input space, either a slice or list
of indices into the columns of the space. If `None`, all dimensions are active.
"""
# subclasses should also initialise this using the global_search_space
# Ensure global_points is a copied tensor, in case a variable is passed in.
DiscreteSearchSpace.__init__(self, tf.constant(global_search_space.points))
UpdatableTrustRegion.__init__(self, region_index, input_active_dims)
self._global_search_space = global_search_space

Expand All @@ -2258,19 +2267,36 @@ def location(self, location: TensorType) -> None:
self._location_ix = tf.squeeze(location_ix, axis=-1)

@property
def global_search_space(self) -> DiscreteSearchSpaceType:
def global_search_space(self) -> GeneralDiscreteSearchSpace:
return self._global_search_space

def _compute_global_distances(self) -> TensorType:
    """
    Compute the pairwise distances between all points of the global search space.

    When the global search space is an instance of ``HasOneHotEncoder`` (a categorical
    space), the distance between two points is their Hamming distance, i.e. the number of
    coordinates in which they differ. Otherwise the absolute difference is returned
    separately along each axis.

    :return: A tensor of shape [num_points, num_points, 1] for categorical spaces, or of
        shape [num_points, num_points, D] for all other spaces.
    """
    global_points = self.global_search_space.points
    # Insert singleton axes so that combining the two operands broadcasts to all pairs.
    as_rows = tf.expand_dims(global_points, -2)
    as_cols = tf.expand_dims(global_points, -3)

    if not isinstance(self.global_search_space, HasOneHotEncoder):
        # Per-axis absolute difference.
        return tf.abs(as_rows - as_cols)  # [num_points, num_points, D]

    # Hamming distance for categorical spaces: mark each differing coordinate with 1
    # and each matching coordinate with 0, then count the marks per pair.
    mismatches = tf.where(as_rows == as_cols, 0, 1)
    # The last axis is kept (with size 1) so that callers can reduce over it in the same
    # way as for the per-axis distances returned above.
    return tf.math.reduce_sum(
        mismatches, axis=-1, keepdims=True
    )  # [num_points, num_points, 1]

def _get_points_within_distance(
self, global_distances: TensorType, eps: TensorType
self, global_distances: TensorType, distance: TensorType
) -> TensorType:
# Helper method to return subset of global points within a given `eps` distance of the
# Helper method to return subset of global points within a given distance of the
# region location. Takes the precomputed pairwise distances tensor and the trust region
# size `eps`.
# size `eps` (or a hard-coded value of 1 in the case of categorical spaces).

# Indices of the neighbors within the trust region.
neighbors_mask = tf.reduce_all(
tf.gather(global_distances, self._location_ix) <= eps, axis=-1
tf.gather(global_distances, self._location_ix) <= distance, axis=-1
)
neighbors_mask = tf.reshape(neighbors_mask, (-1,))
neighbor_ixs = tf.where(neighbors_mask)
Expand All @@ -2279,44 +2305,22 @@ def _get_points_within_distance(
return tf.gather(self.global_search_space.points, neighbor_ixs)


class UpdatableTrustRegionDiscrete(
DiscreteSearchSpace, UpdatableTrustRegionDiscreteMixin[DiscreteSearchSpace]
):

class FixedPointTrustRegionDiscrete(UpdatableTrustRegionDiscrete):
"""
An updatable discrete search space with an associated global search space.
A discrete trust region with a fixed point location that does not change across active learning
steps. The fixed point is selected at random from the global (discrete) search space at
initialization time.
"""

def __init__(
self,
global_search_space: DiscreteSearchSpace,
global_search_space: GeneralDiscreteSearchSpace,
region_index: Optional[int] = None,
input_active_dims: Optional[Union[slice, Sequence[int]]] = None,
):
# Ensure global_points is a copied tensor, in case a variable is passed in.
DiscreteSearchSpace.__init__(self, tf.constant(global_search_space.points))
UpdatableTrustRegionDiscreteMixin.__init__(
self, global_search_space, region_index, input_active_dims
)

def _compute_global_distances(self) -> TensorType:
# Helper method to compute and return pairwise distances along each axis in the
# global search space.
points = self.global_search_space.points
return tf.abs(
tf.expand_dims(points, -2) - tf.expand_dims(points, -3)
) # [num_points, num_points, D]


# TODO: define UpdatableTrustRegionCategorical


class FixedPointTrustRegionMixin(UpdatableTrustRegion):
"""
A mixin for discrete trust regions with fixed point locations that do not change across
active learning steps. The fixed point is selected at random from the global (discrete)
search space at initialization time.
"""
super().__init__(global_search_space, region_index, input_active_dims)
# Random initial point from the global search space.
self._init_location()

def initialize(
self,
Expand All @@ -2336,21 +2340,6 @@ def update(
pass


class FixedPointTrustRegionDiscrete(UpdatableTrustRegionDiscrete, FixedPointTrustRegionMixin):
def __init__(
self,
global_search_space: DiscreteSearchSpace,
region_index: Optional[int] = None,
input_active_dims: Optional[Union[slice, Sequence[int]]] = None,
):
super().__init__(global_search_space, region_index, input_active_dims)
# Random initial point from the global search space.
self._init_location()


# TODO: define FixedPointTrustRegionCategorical


class SingleObjectiveTrustRegionDiscrete(UpdatableTrustRegionDiscrete, HypercubeTrustRegion):
"""
An updatable discrete trust region that maintains a set of neighboring points around a
Expand All @@ -2370,7 +2359,7 @@ class SingleObjectiveTrustRegionDiscrete(UpdatableTrustRegionDiscrete, Hypercube

def __init__(
self,
global_search_space: DiscreteSearchSpace,
global_search_space: GeneralDiscreteSearchSpace,
beta: float = 0.7,
kappa: float = 1e-4,
zeta: float = 0.5,
Expand Down Expand Up @@ -2408,9 +2397,6 @@ def _update_domain(self) -> None:
self._points = self._get_points_within_distance(self._global_distances, self.eps)


# TODO: define SingleObjectiveTrustRegionCategorical


class UpdatableTrustRegionProduct(TaggedProductSearchSpace, UpdatableTrustRegion):
"""
An updatable mixed search space that is the product of multiple updatable trust sub-regions.
Expand Down

0 comments on commit 003d8a7

Please sign in to comment.