Skip to content

Commit

Permalink
Hotfix: OLS ended() method
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasAlegre committed Feb 6, 2024
1 parent ebf587a commit d39ac0e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
19 changes: 15 additions & 4 deletions morl_baselines/multi_policy/linear_support/linear_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
self.weight_support = [] # List of weight vectors for each value vector in the CCS
self.queue = []
self.iteration = 0
self.ols_ended = False
self.verbose = verbose
for w in extrema_weights(self.num_objectives):
self.queue.append((float("inf"), w))
Expand Down Expand Up @@ -102,6 +103,7 @@ def next_weight(
if len(self.queue) == 0:
if self.verbose:
print("There are no corner weights in the queue. Returning None.")
self.ols_ended = True
return None
else:
next_w = self.queue.pop(0)[1]
Expand Down Expand Up @@ -134,8 +136,14 @@ def get_corner_weights(self, top_k: Optional[int] = None) -> List[np.ndarray]:
return weights

def ended(self) -> bool:
"""Returns True if the queue is empty."""
return len(self.queue) == 0
"""Returns True if there are no more corner weights to test.
Warning: This method must be called AFTER calling next_weight().
Ex: w = ols.next_weight()
if ols.ended():
print("OLS ended.")
"""
return self.ols_ended

def add_solution(self, value: np.ndarray, w: np.ndarray) -> List[int]:
"""Add new value vector optimal to weight w.
Expand Down Expand Up @@ -375,10 +383,13 @@ def is_dominated(self, value: np.ndarray) -> bool:
def _solve(w):
return np.array(list(map(float, input().split())), dtype=np.float32)

num_objectives = 3
num_objectives = 2
ols = LinearSupport(num_objectives=num_objectives, epsilon=0.0001, verbose=True)
while not ols.ended():
while True:
w = ols.next_weight()
if ols.ended():
print("OLS ended.")
break
print("w:", w)
value = _solve(w)
ols.add_solution(value, w)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,10 @@ def test_ols():

ols = LinearSupport(num_objectives=2, epsilon=0.1, verbose=False)
policies = []
while not ols.ended():
while True:
w = ols.next_weight()
if ols.ended():
break

new_policy = MOQLearning(
env,
Expand Down

0 comments on commit d39ac0e

Please sign in to comment.