Skip to content

Commit

Permalink
Avoid duplicating initial points when using vectorization (#875)
Browse files Browse the repository at this point in the history
  • Loading branch information
uri-granta authored Sep 20, 2024
1 parent a9aeeb1 commit 7ceb15a
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def objective(x: TensorType) -> TensorType:
)


@random_seed
def _get_categorical_problem() -> SingleObjectiveTestProblem[TaggedProductSearchSpace]:
# a categorical scaled branin problem with 6 categories mapping to 3 random points
# plus the 3 minimizer points (to guarantee that the minimum is present)
Expand Down
35 changes: 31 additions & 4 deletions tests/unit/acquisition/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ def target_function(x: TensorType) -> TensorType: # [N,V,D] -> [N, V]
return tf.concat(individual_func, axis=-1) # vectorize by repeating same function

optimizer = batchify_vectorize(
generate_continuous_optimizer(num_initial_samples=1_000, num_optimization_runs=10),
generate_continuous_optimizer(num_initial_samples=20_000, num_optimization_runs=10),
batch_size=vectorization,
)
maximizer = optimizer(search_space, target_function)
Expand Down Expand Up @@ -1017,10 +1017,37 @@ def test_sample_from_space(num_samples: int, batch_size: Optional[int]) -> None:
assert len(set(float(x) for batch in batches for x in batch)) == num_samples


@pytest.mark.parametrize("num_samples,batch_size", [(0, None), (-5, None), (5, 0), (5, -5)])
def test_sample_from_space_raises(num_samples: int, batch_size: Optional[int]) -> None:
@pytest.mark.parametrize(
    "space", [Box([0], [1]), TaggedMultiSearchSpace([Box([0], [1]), Box([0], [1])])]
)
def test_sample_from_space_vectorization(space: SearchSpace) -> None:
    sampler = sample_from_space(10, vectorization=4)
    batches = list(sampler(space))
    # expect a single batch of 10 samples, each a vector of 4 one-dimensional points
    assert len(batches) == 1
    (batch,) = batches
    assert batch.shape == [10, 4, 1]
    # all sampled points lie inside the unit interval
    assert 0 <= tf.reduce_min(batch) <= tf.reduce_max(batch) <= 1
    # the vectorized points should differ, not be copies of the first vector entry
    assert tf.reduce_any(tf.not_equal(batch, batch[:, 0:1, :]))


@pytest.mark.parametrize(
    "num_samples,batch_size,vectorization",
    [(0, None, 1), (-5, None, 1), (5, 0, 1), (5, -5, 1), (5, 5, 0), (5, 5, -1)],
)
def test_sample_from_space_raises(
    num_samples: int, batch_size: Optional[int], vectorization: int
) -> None:
    # non-positive sample counts, batch sizes and vectorizations are all rejected
    kwargs = dict(
        num_samples=num_samples, batch_size=batch_size, vectorization=vectorization
    )
    with pytest.raises(ValueError):
        sample_from_space(**kwargs)


def test_sample_from_space_vectorization_raises_with_invalid_space() -> None:
    # a vectorization of 3 cannot be spread evenly over a two-subspace multisearchspace
    two_subspace = TaggedMultiSearchSpace([Box([0], [1]), Box([0], [1])])
    sample = sample_from_space(10, vectorization=3)
    with pytest.raises(tf.errors.InvalidArgumentError):
        list(sample(two_subspace))


def test_optimize_continuous_raises_for_insufficient_starting_points() -> None:
Expand Down
51 changes: 41 additions & 10 deletions trieste/acquisition/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,22 +193,53 @@ def sampler(space: SearchSpace) -> Iterable[TensorType]:
"""


def sample_from_space(num_samples: int, batch_size: Optional[int] = None) -> InitialPointSampler:
def sample_from_space(
    num_samples: int,
    batch_size: Optional[int] = None,
    vectorization: int = 1,
) -> InitialPointSampler:
    """
    An initial point sampler that just samples from the search space.

    :param num_samples: Number of samples to return.
    :param batch_size: If specified, points are returned in batches of this size,
        to preserve memory usage.
    :param vectorization: Vectorization of the target function.
    :raise ValueError: If num_samples, batch_size or vectorization is not positive.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    if isinstance(batch_size, int) and batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    if vectorization <= 0:
        raise ValueError(f"vectorization must be positive, got {vectorization}")

    # with no explicit batch size, return all the samples in one batch
    batch_size_int = batch_size or num_samples

    def sampler(space: SearchSpace) -> Iterable[TensorType]:

        # generate additional points for each vectorization (rather than just replicating them)
        if isinstance(space, TaggedMultiSearchSpace):
            # the vectorization must cover the subspaces a whole number of times
            remainder = vectorization % len(space.subspace_tags)
            tf.debugging.assert_equal(
                remainder,
                0,
                message=(
                    f"The vectorization of the target function {vectorization} must be a "
                    f"multiple of the batch shape of initial samples "
                    f"{len(space.subspace_tags)}."
                ),
            )
            multiple = vectorization // len(space.subspace_tags)
        else:
            multiple = vectorization

        for offset in range(0, num_samples, batch_size_int):
            # the final batch may be smaller than batch_size_int
            num_batch_samples = min(num_samples - offset, batch_size_int)
            candidates = space.sample(num_batch_samples * multiple)
            candidates = tf.reshape(candidates, [num_batch_samples, vectorization, -1])
            yield candidates

    return sampler

Expand Down Expand Up @@ -363,12 +394,6 @@ def generate_continuous_optimizer(
if num_recovery_runs < 0:
raise ValueError(f"num_recovery_runs must be zero or greater, got {num_recovery_runs}")

initial_sampler = (
sample_from_space(num_initial_samples)
if not callable(num_initial_samples)
else num_initial_samples
)

def optimize_continuous(
space: Box | CollectionSearchSpace,
target_func: Union[AcquisitionFunction, Tuple[AcquisitionFunction, int]],
Expand Down Expand Up @@ -400,6 +425,12 @@ def optimize_continuous(
if V <= 0:
raise ValueError(f"vectorization must be positive, got {V}")

initial_sampler = (
sample_from_space(num_initial_samples, vectorization=V)
if not callable(num_initial_samples)
else num_initial_samples
)

initial_points = generate_initial_points(
num_optimization_runs, initial_sampler, space, target_func, V
) # [num_optimization_runs,V,D]
Expand Down

0 comments on commit 7ceb15a

Please sign in to comment.