From 7de44678f1312e01b372aaec4ebdd0c3437fc72c Mon Sep 17 00:00:00 2001 From: Elizabeth Santorella Date: Fri, 25 Oct 2024 06:12:44 -0700 Subject: [PATCH] Introduce `SurrogateTestFunction` (#2953) Summary: **Context**: This next diff will cut over uses of `SurrogateRunner` to use `ParamBasedTestProblemRunner` with a `test_problem` that is the newly introduced `SurrogateTestFunction`, and the following diff after that will bring us down to only one runner class for benchmarking by merging `ParamBasedTestProblemRunner` into `BenchmarkRunner`. Having only one runner will make it easier to enable asynchronous benchmarks. Currently, SurrogateRunner had its own logic for tracking when trials are completed, which would make it difficult to work in with asynchronicity. **Note on naming**: Some names have become non-intuitive in the process of benchmarking. To accord with some future changes I hope to make, I called a new class SurrogateTestFunction, whereas SurrogateParamBasedTestProblem would be more in line with the old naming. The name changes I hope to make: * ParamBasedTestProblemRunner -> nothing, absorbed into BenchmarkRunner * ParamBasedTestProblem -> TestFunction, to emphasize that all it does is generate data (rather than more generally specify the problem we are solving) and that it is deterministic, and to differentiate it from BenchmarkProblem. BenchmarkTestFunction would also be a candidate. * BoTorchTestProblem -> BoTorchTestFunction **Changes in this diff**: * Introduces SurrogateTestFunction, a ParamBasedTestProblem for surrogates, giving it the surrogate-related logic from SurrogateRunner Reviewed By: saitcakmak Differential Revision: D64899032 --- ax/benchmark/runners/surrogate.py | 80 +++++++++++++++++ .../runners/test_botorch_test_problem.py | 85 +++++++++++++------ .../tests/runners/test_surrogate_runner.py | 78 ++++++++++++++++- ax/utils/testing/benchmark_stubs.py | 60 ++++++++++++- 4 files changed, 270 insertions(+), 33 deletions(-) diff --git a/ax/benchmark/runners/surrogate.py b/ax/benchmark/runners/surrogate.py index 42c39a1bbbf..3a8bf83ba4c 100644 --- a/ax/benchmark/runners/surrogate.py +++ b/ax/benchmark/runners/surrogate.py @@ -11,9 +11,11 @@ import torch from ax.benchmark.runners.base import BenchmarkRunner +from ax.benchmark.runners.botorch_test import ParamBasedTestProblem from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.observation import ObservationFeatures from ax.core.search_space import SearchSpaceDigest +from ax.core.types import TParamValue from ax.modelbridge.torch import TorchModelBridge from ax.utils.common.base import Base from ax.utils.common.equality import equality_typechecker @@ -22,6 +24,84 @@ from torch import Tensor +@dataclass(kw_only=True) +class SurrogateTestFunction(ParamBasedTestProblem): + """ + Data-generating function for surrogate benchmark problems. + + Args: + name: The name of the runner. + outcome_names: Names of outcomes to return in `evaluate_true`, if the + surrogate produces more outcomes than are needed. + _surrogate: Either `None`, or a `TorchModelBridge` surrogate to use + for generating observations. If `None`, `get_surrogate_and_datasets` + must not be None and will be used to generate the surrogate when it + is needed. + _datasets: Either `None`, or the `SupervisedDataset`s used to fit + the surrogate model. If `None`, `get_surrogate_and_datasets` must + not be None and will be used to generate the datasets when they are + needed. + get_surrogate_and_datasets: Function that returns the surrogate and + datasets, to allow for lazy construction. If + `get_surrogate_and_datasets` is not provided, `surrogate` and + `datasets` must be provided, and vice versa. + """ + + name: str + outcome_names: list[str] + _surrogate: TorchModelBridge | None = None + _datasets: list[SupervisedDataset] | None = None + get_surrogate_and_datasets: ( + None | Callable[[], tuple[TorchModelBridge, list[SupervisedDataset]]] + ) = None + + def __post_init__(self) -> None: + if self.get_surrogate_and_datasets is None and ( + self._surrogate is None or self._datasets is None + ): + raise ValueError( + "If `get_surrogate_and_datasets` is None, `_surrogate` " + "and `_datasets` must not be None, and vice versa." + ) + + def set_surrogate_and_datasets(self) -> None: + self._surrogate, self._datasets = none_throws(self.get_surrogate_and_datasets)() + + @property + def surrogate(self) -> TorchModelBridge: + if self._surrogate is None: + self.set_surrogate_and_datasets() + return none_throws(self._surrogate) + + @property + def datasets(self) -> list[SupervisedDataset]: + if self._datasets is None: + self.set_surrogate_and_datasets() + return none_throws(self._datasets) + + def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor: + # We're ignoring the uncertainty predictions of the surrogate model here and + # use the mean predictions as the outcomes (before potentially adding noise) + means, _ = self.surrogate.predict( + # pyre-fixme[6]: params is a Mapping, but ObservationFeatures expects a Dict + observation_features=[ObservationFeatures(params)] + ) + means = [means[name][0] for name in self.outcome_names] + return torch.tensor( + means, + device=self.surrogate.device, + dtype=self.surrogate.dtype, + ) + + @equality_typechecker + def __eq__(self, other: Base) -> bool: + if type(other) is not type(self): + return False + + # Don't check surrogate, datasets, or callable + return self.name == other.name + + @dataclass class SurrogateRunner(BenchmarkRunner): """Runner for surrogate benchmark problems. diff --git a/ax/benchmark/tests/runners/test_botorch_test_problem.py b/ax/benchmark/tests/runners/test_botorch_test_problem.py index 8f89f2105c1..c32c9a30b7c 100644 --- a/ax/benchmark/tests/runners/test_botorch_test_problem.py +++ b/ax/benchmark/tests/runners/test_botorch_test_problem.py @@ -7,9 +7,10 @@ # pyre-strict +from contextlib import nullcontext from dataclasses import replace from itertools import product -from unittest.mock import Mock +from unittest.mock import Mock, patch import numpy as np @@ -19,13 +20,17 @@ BoTorchTestProblem, ParamBasedTestProblemRunner, ) +from ax.benchmark.runners.surrogate import SurrogateTestFunction from ax.core.arm import Arm from ax.core.base_trial import TrialStatus from ax.core.trial import Trial from ax.exceptions.core import UnsupportedError from ax.utils.common.testutils import TestCase from ax.utils.common.typeutils import checked_cast -from ax.utils.testing.benchmark_stubs import TestParamBasedTestProblem +from ax.utils.testing.benchmark_stubs import ( + get_soo_surrogate_test_function, + TestParamBasedTestProblem, +) from botorch.test_functions.multi_objective import BraninCurrin from botorch.test_functions.synthetic import Ackley, ConstrainedHartmann, Hartmann from botorch.utils.transforms import normalize @@ -130,34 +135,35 @@ def test_synthetic_runner(self) -> None: for num_outcomes in (1, 2) for noise_std in (0.0, [float(i) for i in range(num_outcomes)]) ] - for test_problem, noise_std, num_outcomes in botorch_cases + param_based_cases: - is_constrained = isinstance( - test_problem, BoTorchTestProblem - ) and isinstance(test_problem.botorch_problem, ConstrainedHartmann) - num_constraints = 1 if is_constrained else 0 - outcome_names = [ - f"objective_{i}" for i in range(num_outcomes - num_constraints) - ] + ["constraint"] * num_constraints + surrogate_cases = [ + (get_soo_surrogate_test_function(lazy=False), noise_std, 1) + for noise_std in (0.0, 1.0, [0.0], [1.0]) + ] + for test_problem, noise_std, num_outcomes in ( + botorch_cases + param_based_cases + surrogate_cases + ): + # Set up outcome names + if isinstance(test_problem, BoTorchTestProblem): + if isinstance(test_problem.botorch_problem, ConstrainedHartmann): + outcome_names = ["objective_0", "constraint"] + else: + outcome_names = ["objective_0"] + elif isinstance(test_problem, TestParamBasedTestProblem): + outcome_names = [f"objective_{i}" for i in range(num_outcomes)] + else: # SurrogateTestFunction + outcome_names = ["branin"] + # Set up runner runner = ParamBasedTestProblemRunner( test_problem=test_problem, outcome_names=outcome_names, noise_std=noise_std, ) - modified_bounds = ( - test_problem.modified_bounds - if isinstance(test_problem, BoTorchTestProblem) - else None - ) - - test_description: str = ( - f"test problem: {test_problem.__class__.__name__}, " - f"modified_bounds: {modified_bounds}, " - f"noise_std: {noise_std}." - ) - is_botorch = isinstance(test_problem, BoTorchTestProblem) - with self.subTest(f"Test basic construction, {test_description}"): + test_description = f"{test_problem=}, {noise_std=}" + with self.subTest( + f"Test basic construction, {test_problem=}, {noise_std=}" + ): self.assertIs(runner.test_problem, test_problem) self.assertEqual(runner.outcome_names, outcome_names) if isinstance(noise_std, list): @@ -183,6 +189,7 @@ def test_synthetic_runner(self) -> None: test_problem.botorch_problem.bounds.dtype, torch.double ) + is_botorch = isinstance(test_problem, BoTorchTestProblem) with self.subTest(f"test `get_Y_true()`, {test_description}"): dim = 6 if is_botorch else 9 X = torch.rand(1, dim, dtype=torch.double) @@ -195,7 +202,20 @@ def test_synthetic_runner(self) -> None: ) params = dict(zip(param_names, (x.item() for x in X.unbind(-1)))) - Y = runner.get_Y_true(params=params) + with ( + nullcontext() + if not isinstance(test_problem, SurrogateTestFunction) + else patch.object( + # pyre-fixme: ParamBasedTestProblem` has no attribute + # `_surrogate`. + runner.test_problem._surrogate, + "predict", + return_value=({"branin": [4.2]}, None), + ) + ): + Y = runner.get_Y_true(params=params) + oracle = runner.evaluate_oracle(parameters=params) + if ( isinstance(test_problem, BoTorchTestProblem) and test_problem.modified_bounds is not None @@ -221,12 +241,13 @@ def test_synthetic_runner(self) -> None: ) else: expected_Y = obj + elif isinstance(test_problem, SurrogateTestFunction): + expected_Y = torch.tensor([4.2], dtype=torch.double) else: expected_Y = torch.full( torch.Size([2]), X.pow(2).sum().item(), dtype=torch.double ) self.assertTrue(torch.allclose(Y, expected_Y)) - oracle = runner.evaluate_oracle(parameters=params) self.assertTrue(np.equal(Y.numpy(), oracle).all()) with self.subTest(f"test `run()`, {test_description}"): @@ -237,7 +258,19 @@ def test_synthetic_runner(self) -> None: trial.arms = [arm] trial.arm = arm trial.index = 0 - res = runner.run(trial=trial) + + with ( + nullcontext() + if not isinstance(test_problem, SurrogateTestFunction) + else patch.object( + # pyre-fixme: ParamBasedTestProblem` has no attribute + # `_surrogate`. + runner.test_problem._surrogate, + "predict", + return_value=({"branin": [4.2]}, None), + ) + ): + res = runner.run(trial=trial) self.assertEqual({"Ys", "Ystds", "outcome_names"}, res.keys()) self.assertEqual({"0_0"}, res["Ys"].keys()) diff --git a/ax/benchmark/tests/runners/test_surrogate_runner.py b/ax/benchmark/tests/runners/test_surrogate_runner.py index e0046c32433..099d09f2d5e 100644 --- a/ax/benchmark/tests/runners/test_surrogate_runner.py +++ b/ax/benchmark/tests/runners/test_surrogate_runner.py @@ -8,12 +8,82 @@ from unittest.mock import MagicMock, patch import torch -from ax.benchmark.runners.surrogate import SurrogateRunner +from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction from ax.core.parameter import ParameterType, RangeParameter from ax.core.search_space import SearchSpace from ax.modelbridge.torch import TorchModelBridge from ax.utils.common.testutils import TestCase -from ax.utils.testing.benchmark_stubs import get_soo_surrogate +from ax.utils.testing.benchmark_stubs import ( + get_soo_surrogate_legacy, + get_soo_surrogate_test_function, +) + + +class TestSurrogateTestFunction(TestCase): + def test_surrogate_test_function(self) -> None: + # Construct a search space with log-scale parameters. + for noise_std in (0.0, 0.1, {"dummy_metric": 0.2}): + with self.subTest(noise_std=noise_std): + surrogate = MagicMock() + mock_mean = torch.tensor([[0.1234]], dtype=torch.double) + surrogate.predict = MagicMock(return_value=(mock_mean, 0)) + surrogate.device = torch.device("cpu") + surrogate.dtype = torch.double + test_function = SurrogateTestFunction( + name="test test function", + outcome_names=["dummy metric"], + _surrogate=surrogate, + _datasets=[], + ) + self.assertEqual(test_function.name, "test test function") + self.assertIs(test_function.surrogate, surrogate) + + def test_lazy_instantiation(self) -> None: + test_function = get_soo_surrogate_test_function() + + self.assertIsNone(test_function._surrogate) + self.assertIsNone(test_function._datasets) + + # Accessing `surrogate` sets datasets and surrogate + self.assertIsInstance(test_function.surrogate, TorchModelBridge) + self.assertIsInstance(test_function._surrogate, TorchModelBridge) + self.assertIsInstance(test_function._datasets, list) + + # Accessing `datasets` also sets datasets and surrogate + test_function = get_soo_surrogate_test_function() + self.assertIsInstance(test_function.datasets, list) + self.assertIsInstance(test_function._surrogate, TorchModelBridge) + self.assertIsInstance(test_function._datasets, list) + + with patch.object( + test_function, + "get_surrogate_and_datasets", + wraps=test_function.get_surrogate_and_datasets, + ) as mock_get_surrogate_and_datasets: + test_function.surrogate + mock_get_surrogate_and_datasets.assert_not_called() + + def test_instantiation_raises_with_missing_args(self) -> None: + with self.assertRaisesRegex( + ValueError, "If `get_surrogate_and_datasets` is None, `_surrogate` and " + ): + SurrogateTestFunction(name="test runner", outcome_names=[]) + + def test_equality(self) -> None: + def _construct_test_function(name: str) -> SurrogateTestFunction: + return SurrogateTestFunction( + name=name, + _surrogate=MagicMock(), + _datasets=[], + outcome_names=["dummy_metric"], + ) + + runner_1 = _construct_test_function("test 1") + runner_2 = _construct_test_function("test 2") + runner_1a = _construct_test_function("test 1") + self.assertEqual(runner_1, runner_1a) + self.assertNotEqual(runner_1, runner_2) + self.assertNotEqual(runner_1, 1) class TestSurrogateRunner(TestCase): @@ -49,7 +119,7 @@ def test_surrogate_runner(self) -> None: self.assertEqual(runner.noise_stds, noise_std) def test_lazy_instantiation(self) -> None: - runner = get_soo_surrogate().runner + runner = get_soo_surrogate_legacy().runner self.assertIsNone(runner._surrogate) self.assertIsNone(runner._datasets) @@ -60,7 +130,7 @@ def test_lazy_instantiation(self) -> None: self.assertIsInstance(runner._datasets, list) # Accessing `datasets` also sets datasets and surrogate - runner = get_soo_surrogate().runner + runner = get_soo_surrogate_legacy().runner self.assertIsInstance(runner.datasets, list) self.assertIsInstance(runner._surrogate, TorchModelBridge) self.assertIsInstance(runner._datasets, list) diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py index 0c1a37c5a47..60244511b6b 100644 --- a/ax/utils/testing/benchmark_stubs.py +++ b/ax/utils/testing/benchmark_stubs.py @@ -16,8 +16,11 @@ from ax.benchmark.benchmark_problem import BenchmarkProblem, create_problem_from_botorch from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult from ax.benchmark.problems.surrogate import SurrogateBenchmarkProblem -from ax.benchmark.runners.botorch_test import ParamBasedTestProblem -from ax.benchmark.runners.surrogate import SurrogateRunner +from ax.benchmark.runners.botorch_test import ( + ParamBasedTestProblem, + ParamBasedTestProblemRunner, +) +from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction from ax.core.experiment import Experiment from ax.core.objective import MultiObjective, Objective from ax.core.optimization_config import ( @@ -75,7 +78,58 @@ def get_multi_objective_benchmark_problem( ) -def get_soo_surrogate() -> SurrogateBenchmarkProblem: +def get_soo_surrogate_test_function(lazy: bool = True) -> SurrogateTestFunction: + experiment = get_branin_experiment(with_completed_trial=True) + surrogate = TorchModelBridge( + experiment=experiment, + search_space=experiment.search_space, + model=BoTorchModel(surrogate=Surrogate(botorch_model_class=SingleTaskGP)), + data=experiment.lookup_data(), + transforms=[], + ) + if lazy: + test_function = SurrogateTestFunction( + outcome_names=["branin"], + name="test", + get_surrogate_and_datasets=lambda: (surrogate, []), + ) + else: + test_function = SurrogateTestFunction( + outcome_names=["branin"], + name="test", + _surrogate=surrogate, + _datasets=[], + ) + return test_function + + +def get_soo_surrogate() -> BenchmarkProblem: + experiment = get_branin_experiment(with_completed_trial=True) + test_function = get_soo_surrogate_test_function() + runner = ParamBasedTestProblemRunner( + test_problem=test_function, outcome_names=["branin"] + ) + + observe_noise_sd = True + objective = Objective( + metric=BenchmarkMetric( + name="branin", lower_is_better=True, observe_noise_sd=observe_noise_sd + ), + ) + optimization_config = OptimizationConfig(objective=objective) + + return BenchmarkProblem( + name="test", + search_space=experiment.search_space, + optimization_config=optimization_config, + num_trials=6, + observe_noise_stds=observe_noise_sd, + optimal_value=0.0, + runner=runner, + ) + + +def get_soo_surrogate_legacy() -> SurrogateBenchmarkProblem: experiment = get_branin_experiment(with_completed_trial=True) surrogate = TorchModelBridge( experiment=experiment,