Skip to content

Commit

Permalink
Merge pull request #130 from LTluttmann/fix-warmup-bug
Browse files Browse the repository at this point in the history
[BugFix] fix bugs in warmup baseline
  • Loading branch information
fedebotu authored Mar 9, 2024
2 parents fd58215 + 246f33b commit c817264
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions rl4co/models/rl/reinforce/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,16 @@ def eval(self, td, reward, env=None):
v_b, l_b = self.baseline.eval(td, reward, env)
v_wb, l_wb = self.warmup_baseline.eval(td, reward, env)
# Return convex combination of baseline and of loss
return self.alpha * v_b + (1 - self.alpha) * v_wb, self.alpha * l_b + (
1 - self.alpha * l_wb
return (
self.alpha * v_b + (1 - self.alpha) * v_wb,
self.alpha * l_b + (1 - self.alpha) * l_wb,
)

def epoch_callback(self, *args, **kw):
# Need to call epoch callback of inner model (also after first epoch if we have not used it)
self.baseline.epoch_callback(*args, **kw)
self.alpha = (kw["epoch"] + 1) / float(self.n_epochs)
if kw["epoch"] < self.n_epochs:
self.alpha = (kw["epoch"] + 1) / float(self.n_epochs)
log.info("Set warmup alpha = {}".format(self.alpha))


Expand Down Expand Up @@ -293,9 +294,9 @@ def get_reinforce_baseline(name, **kw):
return WarmupBaseline(
RolloutBaseline(bl_alpha=bl_alpha), warmup_epochs, warmup_exp_beta
)

if name is None:
name = "no" # default to no baseline
name = "no" # default to no baseline
baseline_cls = REINFORCE_BASELINES_REGISTRY.get(name, None)
if baseline_cls is None:
raise ValueError(
Expand Down

0 comments on commit c817264

Please sign in to comment.