Skip to content

Commit

Permalink
Introduce SurrogateTestFunction (#2953)
Browse files Browse the repository at this point in the history
Summary:

**Context**: This next diff will cut over uses of `SurrogateRunner` to use `ParamBasedTestProblemRunner` with a `test_problem` that is the newly introduced `SurrogateTestFunction`, and the following diff after that will bring us down to only one runner class for benchmarking by merging `ParamBasedTestProblemRunner` into `BenchmarkRunner`. Having only one runner will make it easier to enable asynchronous benchmarks. Currently, SurrogateRunner had its own logic for tracking when trials are completed, which would make it difficult to work in with asynchronicity.

**Note on naming**: Some names have become non-intuitive in the process of benchmarking. To accord with some future changes I hope to make, I called a new class SurrogateTestFunction, whereas SurrogateParamBasedTestProblem would be more in line with the old naming.

The name changes I hope to make:
* ParamBasedTestProblemRunner -> nothing, absorbed into BenchmarkRunner
* ParamBasedTestProblem -> TestFunction, to emphasize that all it does is generate data (rather than more generally specify the problem we are solving) and that it is deterministic, and to differentiate it from BenchmarkProblem. BenchmarkTestFunction would also be a candidate.
* BoTorchTestProblem -> BoTorchTestFunction

**Changes in this diff**:
* Introduces SurrogateTestFunction, a ParamBasedTestProblem for surrogates, giving it the surrogate-related logic from SurrogateRunner

Reviewed By: saitcakmak

Differential Revision: D64899032
  • Loading branch information
esantorella authored and facebook-github-bot committed Oct 25, 2024
1 parent ec5a6bd commit 7de4467
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 33 deletions.
80 changes: 80 additions & 0 deletions ax/benchmark/runners/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@

import torch
from ax.benchmark.runners.base import BenchmarkRunner
from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.observation import ObservationFeatures
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TParamValue
from ax.modelbridge.torch import TorchModelBridge
from ax.utils.common.base import Base
from ax.utils.common.equality import equality_typechecker
Expand All @@ -22,6 +24,84 @@
from torch import Tensor


@dataclass(kw_only=True)
class SurrogateTestFunction(ParamBasedTestProblem):
"""
Data-generating function for surrogate benchmark problems.
Args:
name: The name of the runner.
outcome_names: Names of outcomes to return in `evaluate_true`, if the
surrogate produces more outcomes than are needed.
_surrogate: Either `None`, or a `TorchModelBridge` surrogate to use
for generating observations. If `None`, `get_surrogate_and_datasets`
must not be None and will be used to generate the surrogate when it
is needed.
_datasets: Either `None`, or the `SupervisedDataset`s used to fit
the surrogate model. If `None`, `get_surrogate_and_datasets` must
not be None and will be used to generate the datasets when they are
needed.
get_surrogate_and_datasets: Function that returns the surrogate and
datasets, to allow for lazy construction. If
`get_surrogate_and_datasets` is not provided, `surrogate` and
`datasets` must be provided, and vice versa.
"""

name: str
outcome_names: list[str]
_surrogate: TorchModelBridge | None = None
_datasets: list[SupervisedDataset] | None = None
get_surrogate_and_datasets: (
None | Callable[[], tuple[TorchModelBridge, list[SupervisedDataset]]]
) = None

def __post_init__(self) -> None:
if self.get_surrogate_and_datasets is None and (
self._surrogate is None or self._datasets is None
):
raise ValueError(
"If `get_surrogate_and_datasets` is None, `_surrogate` "
"and `_datasets` must not be None, and vice versa."
)

def set_surrogate_and_datasets(self) -> None:
self._surrogate, self._datasets = none_throws(self.get_surrogate_and_datasets)()

@property
def surrogate(self) -> TorchModelBridge:
if self._surrogate is None:
self.set_surrogate_and_datasets()
return none_throws(self._surrogate)

@property
def datasets(self) -> list[SupervisedDataset]:
if self._datasets is None:
self.set_surrogate_and_datasets()
return none_throws(self._datasets)

def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor:
# We're ignoring the uncertainty predictions of the surrogate model here and
# use the mean predictions as the outcomes (before potentially adding noise)
means, _ = self.surrogate.predict(
# pyre-fixme[6]: params is a Mapping, but ObservationFeatures expects a Dict
observation_features=[ObservationFeatures(params)]
)
means = [means[name][0] for name in self.outcome_names]
return torch.tensor(
means,
device=self.surrogate.device,
dtype=self.surrogate.dtype,
)

@equality_typechecker
def __eq__(self, other: Base) -> bool:
if type(other) is not type(self):
return False

# Don't check surrogate, datasets, or callable
return self.name == other.name


@dataclass
class SurrogateRunner(BenchmarkRunner):
"""Runner for surrogate benchmark problems.
Expand Down
85 changes: 59 additions & 26 deletions ax/benchmark/tests/runners/test_botorch_test_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
# pyre-strict


from contextlib import nullcontext
from dataclasses import replace
from itertools import product
from unittest.mock import Mock
from unittest.mock import Mock, patch

import numpy as np

Expand All @@ -19,13 +20,17 @@
BoTorchTestProblem,
ParamBasedTestProblemRunner,
)
from ax.benchmark.runners.surrogate import SurrogateTestFunction
from ax.core.arm import Arm
from ax.core.base_trial import TrialStatus
from ax.core.trial import Trial
from ax.exceptions.core import UnsupportedError
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast
from ax.utils.testing.benchmark_stubs import TestParamBasedTestProblem
from ax.utils.testing.benchmark_stubs import (
get_soo_surrogate_test_function,
TestParamBasedTestProblem,
)
from botorch.test_functions.multi_objective import BraninCurrin
from botorch.test_functions.synthetic import Ackley, ConstrainedHartmann, Hartmann
from botorch.utils.transforms import normalize
Expand Down Expand Up @@ -130,34 +135,35 @@ def test_synthetic_runner(self) -> None:
for num_outcomes in (1, 2)
for noise_std in (0.0, [float(i) for i in range(num_outcomes)])
]
for test_problem, noise_std, num_outcomes in botorch_cases + param_based_cases:
is_constrained = isinstance(
test_problem, BoTorchTestProblem
) and isinstance(test_problem.botorch_problem, ConstrainedHartmann)
num_constraints = 1 if is_constrained else 0
outcome_names = [
f"objective_{i}" for i in range(num_outcomes - num_constraints)
] + ["constraint"] * num_constraints
surrogate_cases = [
(get_soo_surrogate_test_function(lazy=False), noise_std, 1)
for noise_std in (0.0, 1.0, [0.0], [1.0])
]
for test_problem, noise_std, num_outcomes in (
botorch_cases + param_based_cases + surrogate_cases
):
# Set up outcome names
if isinstance(test_problem, BoTorchTestProblem):
if isinstance(test_problem.botorch_problem, ConstrainedHartmann):
outcome_names = ["objective_0", "constraint"]
else:
outcome_names = ["objective_0"]
elif isinstance(test_problem, TestParamBasedTestProblem):
outcome_names = [f"objective_{i}" for i in range(num_outcomes)]
else: # SurrogateTestFunction
outcome_names = ["branin"]

# Set up runner
runner = ParamBasedTestProblemRunner(
test_problem=test_problem,
outcome_names=outcome_names,
noise_std=noise_std,
)
modified_bounds = (
test_problem.modified_bounds
if isinstance(test_problem, BoTorchTestProblem)
else None
)

test_description: str = (
f"test problem: {test_problem.__class__.__name__}, "
f"modified_bounds: {modified_bounds}, "
f"noise_std: {noise_std}."
)
is_botorch = isinstance(test_problem, BoTorchTestProblem)

with self.subTest(f"Test basic construction, {test_description}"):
test_description = f"{test_problem=}, {noise_std=}"
with self.subTest(
f"Test basic construction, {test_problem=}, {noise_std=}"
):
self.assertIs(runner.test_problem, test_problem)
self.assertEqual(runner.outcome_names, outcome_names)
if isinstance(noise_std, list):
Expand All @@ -183,6 +189,7 @@ def test_synthetic_runner(self) -> None:
test_problem.botorch_problem.bounds.dtype, torch.double
)

is_botorch = isinstance(test_problem, BoTorchTestProblem)
with self.subTest(f"test `get_Y_true()`, {test_description}"):
dim = 6 if is_botorch else 9
X = torch.rand(1, dim, dtype=torch.double)
Expand All @@ -195,7 +202,20 @@ def test_synthetic_runner(self) -> None:
)
params = dict(zip(param_names, (x.item() for x in X.unbind(-1))))

Y = runner.get_Y_true(params=params)
with (
nullcontext()
if not isinstance(test_problem, SurrogateTestFunction)
else patch.object(
# pyre-fixme: ParamBasedTestProblem` has no attribute
# `_surrogate`.
runner.test_problem._surrogate,
"predict",
return_value=({"branin": [4.2]}, None),
)
):
Y = runner.get_Y_true(params=params)
oracle = runner.evaluate_oracle(parameters=params)

if (
isinstance(test_problem, BoTorchTestProblem)
and test_problem.modified_bounds is not None
Expand All @@ -221,12 +241,13 @@ def test_synthetic_runner(self) -> None:
)
else:
expected_Y = obj
elif isinstance(test_problem, SurrogateTestFunction):
expected_Y = torch.tensor([4.2], dtype=torch.double)
else:
expected_Y = torch.full(
torch.Size([2]), X.pow(2).sum().item(), dtype=torch.double
)
self.assertTrue(torch.allclose(Y, expected_Y))
oracle = runner.evaluate_oracle(parameters=params)
self.assertTrue(np.equal(Y.numpy(), oracle).all())

with self.subTest(f"test `run()`, {test_description}"):
Expand All @@ -237,7 +258,19 @@ def test_synthetic_runner(self) -> None:
trial.arms = [arm]
trial.arm = arm
trial.index = 0
res = runner.run(trial=trial)

with (
nullcontext()
if not isinstance(test_problem, SurrogateTestFunction)
else patch.object(
# pyre-fixme: ParamBasedTestProblem` has no attribute
# `_surrogate`.
runner.test_problem._surrogate,
"predict",
return_value=({"branin": [4.2]}, None),
)
):
res = runner.run(trial=trial)
self.assertEqual({"Ys", "Ystds", "outcome_names"}, res.keys())
self.assertEqual({"0_0"}, res["Ys"].keys())

Expand Down
78 changes: 74 additions & 4 deletions ax/benchmark/tests/runners/test_surrogate_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,82 @@
from unittest.mock import MagicMock, patch

import torch
from ax.benchmark.runners.surrogate import SurrogateRunner
from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.torch import TorchModelBridge
from ax.utils.common.testutils import TestCase
from ax.utils.testing.benchmark_stubs import get_soo_surrogate
from ax.utils.testing.benchmark_stubs import (
get_soo_surrogate_legacy,
get_soo_surrogate_test_function,
)


class TestSurrogateTestFunction(TestCase):
def test_surrogate_test_function(self) -> None:
# Construct a search space with log-scale parameters.
for noise_std in (0.0, 0.1, {"dummy_metric": 0.2}):
with self.subTest(noise_std=noise_std):
surrogate = MagicMock()
mock_mean = torch.tensor([[0.1234]], dtype=torch.double)
surrogate.predict = MagicMock(return_value=(mock_mean, 0))
surrogate.device = torch.device("cpu")
surrogate.dtype = torch.double
test_function = SurrogateTestFunction(
name="test test function",
outcome_names=["dummy metric"],
_surrogate=surrogate,
_datasets=[],
)
self.assertEqual(test_function.name, "test test function")
self.assertIs(test_function.surrogate, surrogate)

def test_lazy_instantiation(self) -> None:
test_function = get_soo_surrogate_test_function()

self.assertIsNone(test_function._surrogate)
self.assertIsNone(test_function._datasets)

# Accessing `surrogate` sets datasets and surrogate
self.assertIsInstance(test_function.surrogate, TorchModelBridge)
self.assertIsInstance(test_function._surrogate, TorchModelBridge)
self.assertIsInstance(test_function._datasets, list)

# Accessing `datasets` also sets datasets and surrogate
test_function = get_soo_surrogate_test_function()
self.assertIsInstance(test_function.datasets, list)
self.assertIsInstance(test_function._surrogate, TorchModelBridge)
self.assertIsInstance(test_function._datasets, list)

with patch.object(
test_function,
"get_surrogate_and_datasets",
wraps=test_function.get_surrogate_and_datasets,
) as mock_get_surrogate_and_datasets:
test_function.surrogate
mock_get_surrogate_and_datasets.assert_not_called()

def test_instantiation_raises_with_missing_args(self) -> None:
with self.assertRaisesRegex(
ValueError, "If `get_surrogate_and_datasets` is None, `_surrogate` and "
):
SurrogateTestFunction(name="test runner", outcome_names=[])

def test_equality(self) -> None:
def _construct_test_function(name: str) -> SurrogateTestFunction:
return SurrogateTestFunction(
name=name,
_surrogate=MagicMock(),
_datasets=[],
outcome_names=["dummy_metric"],
)

runner_1 = _construct_test_function("test 1")
runner_2 = _construct_test_function("test 2")
runner_1a = _construct_test_function("test 1")
self.assertEqual(runner_1, runner_1a)
self.assertNotEqual(runner_1, runner_2)
self.assertNotEqual(runner_1, 1)


class TestSurrogateRunner(TestCase):
Expand Down Expand Up @@ -49,7 +119,7 @@ def test_surrogate_runner(self) -> None:
self.assertEqual(runner.noise_stds, noise_std)

def test_lazy_instantiation(self) -> None:
runner = get_soo_surrogate().runner
runner = get_soo_surrogate_legacy().runner

self.assertIsNone(runner._surrogate)
self.assertIsNone(runner._datasets)
Expand All @@ -60,7 +130,7 @@ def test_lazy_instantiation(self) -> None:
self.assertIsInstance(runner._datasets, list)

# Accessing `datasets` also sets datasets and surrogate
runner = get_soo_surrogate().runner
runner = get_soo_surrogate_legacy().runner
self.assertIsInstance(runner.datasets, list)
self.assertIsInstance(runner._surrogate, TorchModelBridge)
self.assertIsInstance(runner._datasets, list)
Expand Down
Loading

0 comments on commit 7de4467

Please sign in to comment.