Skip to content

Commit

Permalink
fixing small mistake
Browse files Browse the repository at this point in the history
  • Loading branch information
FilippoAiraldi committed Jun 28, 2024
1 parent 57042c3 commit fbd0a8b
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions benchmarking/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import argparse
import gc
import random
from collections.abc import Iterable
from datetime import datetime
from itertools import cycle, product
Expand Down Expand Up @@ -218,19 +219,19 @@ def next_obs(
budget: int,
) -> tuple[Tensor, Tensor, Union[Idw, Rbf]]:
mdl = get_mdl(X, Y, prev_mdl)
remaining_horizon = min(horizon, budget)
if remaining_horizon == 1:
h = min(horizon, budget)
if h == 1:
acqfun = qIdwAcquisitionFunction(mdl, c1, c2, valfunc_sampler)
X_opt, _ = optimize_acqf(
acqfun, bounds, 1, n_restarts, raw_samples, {"seed": mk_seed()}
)
return X_opt, torch.nan, mdl

n_restarts_ = n_restarts * remaining_horizon
raw_samples_ = raw_samples * remaining_horizon
n_restarts_ = n_restarts * h
raw_samples_ = raw_samples * h
acqfun = Ms(
mdl,
fantasies_samplers,
fantasies_samplers[: h - 1],
qIdwAcquisitionFunction,
kwargs_factory,
valfunc_sampler=valfunc_sampler,
Expand All @@ -239,7 +240,7 @@ def next_obs(
prev_full_opt = None
else:
prev_full_opt = warmstart_multistep(
acqfun, bounds, n_restarts_, raw_samples_, prev_full_opt
acqfun, bounds, n_restarts_, raw_samples_, prev_full_opt[:, :h]
)
full_opt, tree_vals = optimize_acqf(
acqfun,
Expand Down Expand Up @@ -324,6 +325,8 @@ def run_benchmark(
torch.set_default_device(device)
torch.set_default_dtype(torch.float64)
torch.manual_seed(seed)
np.random.rand(seed)
random.seed(seed)
if setup_callback is not None:
setup_callback()
problem, maxiter, regression_type = get_benchmark_problem(problem_name)
Expand Down Expand Up @@ -371,7 +374,7 @@ def run_benchmarks(
delayed(run_benchmark)(
prob,
method,
seeds[prob][trial],
int(seeds[prob][trial]),
csv,
device,
n_init,
Expand Down

0 comments on commit fbd0a8b

Please sign in to comment.