Skip to content

Commit

Permalink
simplified some calls now that MPC is gone
Browse files Browse the repository at this point in the history
  • Loading branch information
FilippoAiraldi committed Dec 8, 2023
1 parent 60c7f23 commit b054278
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 26 deletions.
12 changes: 5 additions & 7 deletions src/globopt/nonmyopic/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
nogil=True,
)
def _compute_myopic_cost(
x_trajectory: Array3d,
x: Array3d,
mdl: RegressorType,
n_samples: int,
c1: float,
Expand All @@ -71,8 +71,7 @@ def _compute_myopic_cost(
y_min = mdl.ym_.min() # ym_ ∈ (1, n_samples, 1)
y_max = mdl.ym_.max()
dym = np.full((1, 1, 1), y_max - y_min)
x_next = x_trajectory[np.newaxis, :, 0, :] # take first element in the horizon
cost = myopic_acquisition(x_next, mdl, c1, c2, None, dym)[0, :, 0] # ∈ (n_samples,)
cost = myopic_acquisition(x.transpose(1, 0, 2), mdl, c1, c2, None, dym)[0, :, 0]

mdl_ = repeat(mdl, n_samples)
lb_ = repeat_along_first_axis(np.expand_dims(lb, 0), n_samples)
Expand Down Expand Up @@ -155,7 +154,7 @@ def _terminal_cost(


def _compute_nonmyopic_cost(
x_trajectory: Array3d,
x: Array3d,
mdl: RegressorType,
n_samples: int,
horizon: int,
Expand All @@ -176,15 +175,14 @@ def _compute_nonmyopic_cost(
dynamics with such predictions."""
np_random = np.random.default_rng(seed)
cost = np.zeros(n_samples)
x_next = x_trajectory[:, 0, np.newaxis, :] # ∈ (n_samples, 1, n_features)
h = 0
while True:
rng = prediction_rng[h] if prediction_rng is not None else None
mdl, y_min, y_max, dym = _advance(x_next, mdl, y_min, y_max, rng)
mdl, y_min, y_max, dym = _advance(x, mdl, y_min, y_max, rng)
h += 1
if h >= horizon:
break
x_next, current_cost = _next_query_point(
x, current_cost = _next_query_point(
mdl, c1, c2, dym, lb, ub, pso_kwargs, np_random
)
cost += (discount**h) * current_cost # accumulate in-place
Expand Down
41 changes: 22 additions & 19 deletions src/globopt/nonmyopic/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,28 @@ def _next_query_point(
lb_acquisition = lb[0]
ub_acquisition = ub[0]
check_acquisition = iteration == 1
vpso_func = lambda x: acquisition(
x.reshape(-1, 1, dim),
mdl,
horizon,
discount,
lb_acquisition,
ub_acquisition,
c1,
c2,
mc_iters,
quasi_mc,
antithetic_variates,
terminal_cost,
pso_kwargs,
check_acquisition,
np_random,
parallel,
False,
)

def vpso_func(x):
return acquisition(
x.transpose(1, 0, 2),
mdl,
horizon,
discount,
lb_acquisition,
ub_acquisition,
c1,
c2,
mc_iters,
quasi_mc,
antithetic_variates,
terminal_cost,
pso_kwargs,
check_acquisition,
np_random,
parallel,
False,
)

x_new, acq_opt, _ = vpso(vpso_func, lb, ub, seed=np_random, **pso_kwargs)
return x_new[0, :dim], acq_opt.item()

Expand Down

0 comments on commit b054278

Please sign in to comment.