Skip to content

Commit

Permalink
Disabled cache_root for ensemble surrogates
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Stone committed Sep 12, 2024
1 parent 90516f4 commit 6d350a4
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions obsidian/optimizer/bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .base import Optimizer

from obsidian.parameters import ParamSpace, Target, Task
from obsidian.surrogates import SurrogateBoTorch, DNN
from obsidian.surrogates import SurrogateBoTorch, EnsembleModel
from obsidian.acquisition import aq_class_dict, aq_defaults, aq_hp_defaults, valid_aqs
from obsidian.surrogates import model_class_dict
from obsidian.objectives import Index_Objective, Objective_Sequence
Expand All @@ -18,7 +18,7 @@
from botorch.sampling.index_sampler import IndexSampler
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.model import ModelList
from botorch.models.model import ModelList, Model
from botorch.utils.sampling import draw_sobol_samples
from botorch.utils.multi_objective.box_decompositions.non_dominated import NondominatedPartitioning

Expand Down Expand Up @@ -478,6 +478,7 @@ def _parse_aq_kwargs(self,
hps: dict,
m_batch: int,
target_locs: list[int],
model: Model,
X_t_pending: Tensor | None = None,
objective: MCAcquisitionObjective | None = None) -> dict:
"""
Expand Down Expand Up @@ -570,6 +571,9 @@ def _parse_aq_kwargs(self,
w = w/torch.sum(torch.abs(w))
aq_kwargs['scalarization_weights'] = w

if any(isinstance(m, EnsembleModel) for m in model.models):
aq_kwargs['cache_root'] = False

return aq_kwargs

def suggest(self,
Expand Down Expand Up @@ -712,7 +716,7 @@ def suggest(self,
if not isinstance(model, ModelListGP):
samplers = []
for m in model.models:
if isinstance(m, DNN):
if isinstance(m, EnsembleModel):
sampler_i = IndexSampler(sample_shape=torch.Size([optim_samples]), seed=self.seed)
else:
sampler_i = SobolQMCNormalSampler(sample_shape=torch.Size([optim_samples]), seed=self.seed)
Expand Down Expand Up @@ -757,7 +761,9 @@ def suggest(self,
# Use aq_kwargs so that extra unnecessary ones in hps get removed for certain aq funcs
aq_kwargs = {'model': model, 'sampler': sampler, 'X_pending': X_t_pending}

aq_kwargs.update(self._parse_aq_kwargs(aq_str, aq_hps, m_batch, target_locs, X_t_pending, objective))
aq_kwargs.update(self._parse_aq_kwargs(aq_str, aq_hps, m_batch,
target_locs, model,
X_t_pending, objective))

# Raise errors related to certain constraints
if aq_str in ['UCB', 'Mean', 'TS', 'SF', 'SR', 'NIPV']:
Expand Down Expand Up @@ -978,7 +984,9 @@ def evaluate(self,
# Use aq_kwargs so that extra unnecessary ones in hps get removed for certain aq funcs
aq_kwargs = {'model': model, 'sampler': None, 'X_pending': X_t_pending}

aq_kwargs.update(self._parse_aq_kwargs(aq_str, aq_hps, X_suggest.shape[0], target_locs, X_t_pending, objective))
aq_kwargs.update(self._parse_aq_kwargs(aq_str, aq_hps, X_suggest.shape[0],
target_locs, model,
X_t_pending, objective))

# If it's random search, no need to evaluate aq
if aq_str == 'RS':
Expand Down

0 comments on commit 6d350a4

Please sign in to comment.