Merge pull request #83 from LucasAlegre/train-config
Log all training parameters given to the algorithm
ffelten authored Dec 13, 2023
2 parents 72161ed + ffbf1ec commit 0907003
Showing 8 changed files with 84 additions and 8 deletions.
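Every file receives the same treatment: the one-off call that logged only the reference point and the known front is expanded into a dictionary containing all arguments passed to train(). For orientation, here is a minimal sketch of what register_additional_config could look like, assuming logging is backed by a Weights & Biases run as elsewhere in MORL-Baselines; the body below is illustrative, not the repository's actual implementation:

    import wandb

    class MOAgent:  # simplified stand-in for the real base class in morl_baselines/common/morl_algorithm.py
        def register_additional_config(self, conf: dict) -> None:
            """Record extra experiment parameters on the active W&B run.

            Sketch only: assumes a run was already started (self.log is True) and
            that values in conf are JSON-serializable, hence the .tolist() calls
            on numpy arrays in the diffs below.
            """
            wandb.config.update(conf)  # each key becomes a queryable run parameter

With this in place, runs can be filtered and compared in the tracking UI by any training argument, not just the reference point.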
12 changes: 11 additions & 1 deletion morl_baselines/multi_policy/capql/capql.py
@@ -400,7 +400,17 @@ def train(
             reset_num_timesteps (bool): Whether to reset the number of timesteps.
         """
         if self.log:
-            self.register_additional_config({"ref_point": ref_point.tolist(), "known_front": known_pareto_front})
+            self.register_additional_config(
+                {
+                    "total_timesteps": total_timesteps,
+                    "ref_point": ref_point.tolist(),
+                    "known_front": known_pareto_front,
+                    "num_eval_weights_for_front": num_eval_weights_for_front,
+                    "num_eval_episodes_for_front": num_eval_episodes_for_front,
+                    "eval_freq": eval_freq,
+                    "reset_num_timesteps": reset_num_timesteps,
+                }
+            )
 
         eval_weights = equally_spaced_weights(self.reward_dim, n=num_eval_weights_for_front)
 
15 changes: 14 additions & 1 deletion morl_baselines/multi_policy/envelope/envelope.py
@@ -504,7 +504,20 @@ def train(
         if eval_env is not None:
             assert ref_point is not None, "Reference point must be provided for the hypervolume computation."
         if self.log:
-            self.register_additional_config({"ref_point": ref_point.tolist(), "known_front": known_pareto_front})
+            self.register_additional_config(
+                {
+                    "total_timesteps": total_timesteps,
+                    "ref_point": ref_point.tolist() if ref_point is not None else None,
+                    "known_front": known_pareto_front,
+                    "weight": weight.tolist() if weight is not None else None,
+                    "total_episodes": total_episodes,
+                    "reset_num_timesteps": reset_num_timesteps,
+                    "eval_freq": eval_freq,
+                    "num_eval_weights_for_front": num_eval_weights_for_front,
+                    "num_eval_episodes_for_front": num_eval_episodes_for_front,
+                    "reset_learning_starts": reset_learning_starts,
+                }
+            )
 
         self.global_step = 0 if reset_num_timesteps else self.global_step
         self.num_episodes = 0 if reset_num_timesteps else self.num_episodes
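Unlike the other algorithms, the envelope variant only asserts a reference point when an eval_env is given, so ref_point (and weight) may legitimately be None here; the hunk above therefore guards both before calling .tolist(). A hypothetical helper capturing that pattern, were one to factor it out:

    def optional_to_list(arr):
        """Serialize an optional numpy array for config logging (hypothetical helper, not in the repo)."""
        return arr.tolist() if arr is not None else None

The unguarded ref_point.tolist() calls elsewhere assume a reference point is available by that line (pql, for example, falls back to self.ref_point first).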
12 changes: 11 additions & 1 deletion morl_baselines/multi_policy/gpi_pd/gpi_pd.py
@@ -809,7 +809,17 @@ def train(
             weight_selection_algo (str): Weight selection algorithm to use.
         """
         if self.log:
-            self.register_additional_config({"ref_point": ref_point.tolist(), "known_front": known_pareto_front})
+            self.register_additional_config(
+                {
+                    "total_timesteps": total_timesteps,
+                    "ref_point": ref_point.tolist(),
+                    "known_front": known_pareto_front,
+                    "num_eval_weights_for_front": num_eval_weights_for_front,
+                    "num_eval_episodes_for_front": num_eval_episodes_for_front,
+                    "timesteps_per_iter": timesteps_per_iter,
+                    "weight_selection_algo": weight_selection_algo,
+                }
+            )
         max_iter = total_timesteps // timesteps_per_iter
         linear_support = LinearSupport(num_objectives=self.reward_dim, epsilon=0.0 if weight_selection_algo == "ols" else None)
 
13 changes: 12 additions & 1 deletion morl_baselines/multi_policy/gpi_pd/gpi_pd_continuous_action.py
@@ -612,7 +612,18 @@ def train(
             eval_freq (int): Number of timesteps between evaluations during an iteration.
         """
         if self.log:
-            self.register_additional_config({"ref_point": ref_point.tolist(), "known_front": known_pareto_front})
+            self.register_additional_config(
+                {
+                    "total_timesteps": total_timesteps,
+                    "ref_point": ref_point.tolist(),
+                    "known_front": known_pareto_front,
+                    "num_eval_weights_for_front": num_eval_weights_for_front,
+                    "num_eval_episodes_for_front": num_eval_episodes_for_front,
+                    "weight_selection_algo": weight_selection_algo,
+                    "timesteps_per_iter": timesteps_per_iter,
+                    "eval_freq": eval_freq,
+                }
+            )
         max_iter = total_timesteps // timesteps_per_iter
         linear_support = LinearSupport(num_objectives=self.reward_dim, epsilon=0.0 if weight_selection_algo == "ols" else None)
 
@@ -178,7 +178,17 @@ def train(
             eval_freq: The frequency of evaluation.
         """
         if self.log:
-            self.register_additional_config({"ref_point": ref_point.tolist(), "known_front": known_pareto_front})
+            self.register_additional_config(
+                {
+                    "total_timesteps": total_timesteps,
+                    "ref_point": ref_point.tolist(),
+                    "known_front": known_pareto_front,
+                    "timesteps_per_iteration": timesteps_per_iteration,
+                    "num_eval_weights_for_front": num_eval_weights_for_front,
+                    "num_eval_episodes_for_front": num_eval_episodes_for_front,
+                    "eval_freq": eval_freq,
+                }
+            )
         num_iterations = int(total_timesteps / timesteps_per_iteration)
         if eval_env is None:
             eval_env = deepcopy(self.env)
10 changes: 9 additions & 1 deletion morl_baselines/multi_policy/pareto_q_learning/pql.py
@@ -222,7 +222,15 @@ def train(
         if ref_point is None:
             ref_point = self.ref_point
         if self.log:
-            self.register_additional_config({"ref_point": ref_point.tolist(), "known_front": known_pareto_front})
+            self.register_additional_config(
+                {
+                    "total_timesteps": total_timesteps,
+                    "ref_point": ref_point.tolist(),
+                    "known_front": known_pareto_front,
+                    "log_every": log_every,
+                    "action_eval": action_eval,
+                }
+            )
 
         while self.global_step < total_timesteps:
             state, _ = self.env.reset()
14 changes: 13 additions & 1 deletion morl_baselines/multi_policy/pcn/pcn.py
@@ -355,7 +355,19 @@ def train(
             num_points_pf: number of points to sample from pareto front for metrics calculation
         """
         if self.log:
-            self.register_additional_config({"ref_point": ref_point.tolist(), "known_front": known_pareto_front})
+            self.register_additional_config(
+                {
+                    "total_timesteps": total_timesteps,
+                    "ref_point": ref_point.tolist(),
+                    "known_front": known_pareto_front,
+                    "num_er_episodes": num_er_episodes,
+                    "num_step_episodes": num_step_episodes,
+                    "num_model_updates": num_model_updates,
+                    "max_return": max_return.tolist(),
+                    "max_buffer_size": max_buffer_size,
+                    "num_points_pf": num_points_pf,
+                }
+            )
         self.global_step = 0
         total_episodes = num_er_episodes
         n_checkpoints = 0
4 changes: 3 additions & 1 deletion morl_baselines/multi_policy/pgmorl/pgmorl.py
@@ -620,7 +620,9 @@ def train(
     ):
         """Trains the agents."""
         if self.log:
-            self.register_additional_config({"ref_point": ref_point.tolist(), "known_front": known_pareto_front})
+            self.register_additional_config(
+                {"total_timesteps": total_timesteps, "ref_point": ref_point.tolist(), "known_front": known_pareto_front}
+            )
         max_iterations = total_timesteps // self.steps_per_iteration // self.num_envs
         iteration = 0
         # Init
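Taken together, a caller no longer needs to do anything extra for its settings to be tracked. A hypothetical invocation follows; the argument names match the capql signature shown above, while the environment choice, agent construction, and values are assumptions for illustration:

    import mo_gymnasium as mo_gym
    import numpy as np

    from morl_baselines.multi_policy.capql.capql import CAPQL

    env = mo_gym.make("mo-hopper-v4")  # assumption: any continuous-action multi-objective env
    agent = CAPQL(env, log=True)       # log=True enables the register_additional_config path

    agent.train(
        total_timesteps=100_000,
        eval_env=mo_gym.make("mo-hopper-v4"),
        ref_point=np.array([-100.0, -100.0, -100.0]),  # assumption: three objectives for this env
        known_pareto_front=None,
        num_eval_weights_for_front=100,
        num_eval_episodes_for_front=5,
        eval_freq=10_000,
    )

After this commit, the training arguments above, apart from eval_env itself, are recorded in the run config instead of only ref_point and known_front.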
