diff --git a/morl_baselines/multi_policy/capql/capql.py b/morl_baselines/multi_policy/capql/capql.py
index e3674706..1ebfd230 100644
--- a/morl_baselines/multi_policy/capql/capql.py
+++ b/morl_baselines/multi_policy/capql/capql.py
@@ -400,7 +400,17 @@ def train(
             reset_num_timesteps (bool): Whether to reset the number of timesteps.
         """
         if self.log:
-            self.register_additional_config({"ref_point": ref_point.tolist(), "known_front": known_pareto_front})
+            self.register_additional_config(
+                {
+                    "total_timesteps": total_timesteps,
+                    "ref_point": ref_point.tolist(),
+                    "known_front": known_pareto_front,
+                    "num_eval_weights_for_front": num_eval_weights_for_front,
+                    "num_eval_episodes_for_front": num_eval_episodes_for_front,
+                    "eval_freq": eval_freq,
+                    "reset_num_timesteps": reset_num_timesteps,
+                }
+            )

         eval_weights = equally_spaced_weights(self.reward_dim, n=num_eval_weights_for_front)

diff --git a/morl_baselines/multi_policy/envelope/envelope.py b/morl_baselines/multi_policy/envelope/envelope.py
index 41b934bc..449ba4f8 100644
--- a/morl_baselines/multi_policy/envelope/envelope.py
+++ b/morl_baselines/multi_policy/envelope/envelope.py
@@ -504,7 +504,20 @@ def train(
         if eval_env is not None:
             assert ref_point is not None, "Reference point must be provided for the hypervolume computation."
         if self.log:
-            self.register_additional_config({"ref_point": ref_point.tolist(), "known_front": known_pareto_front})
+            self.register_additional_config(
+                {
+                    "total_timesteps": total_timesteps,
+                    "ref_point": ref_point.tolist() if ref_point is not None else None,
+                    "known_front": known_pareto_front,
+                    "weight": weight.tolist() if weight is not None else None,
+                    "total_episodes": total_episodes,
+                    "reset_num_timesteps": reset_num_timesteps,
+                    "eval_freq": eval_freq,
+                    "num_eval_weights_for_front": num_eval_weights_for_front,
+                    "num_eval_episodes_for_front": num_eval_episodes_for_front,
+                    "reset_learning_starts": reset_learning_starts,
+                }
+            )

         self.global_step = 0 if reset_num_timesteps else self.global_step
         self.num_episodes = 0 if reset_num_timesteps else self.num_episodes
diff --git a/morl_baselines/multi_policy/gpi_pd/gpi_pd.py b/morl_baselines/multi_policy/gpi_pd/gpi_pd.py
index 9bf9a4a4..50153bf8 100644
--- a/morl_baselines/multi_policy/gpi_pd/gpi_pd.py
+++ b/morl_baselines/multi_policy/gpi_pd/gpi_pd.py
@@ -809,7 +809,17 @@ def train(
             weight_selection_algo (str): Weight selection algorithm to use.
         """
         if self.log:
-            self.register_additional_config({"ref_point": ref_point.tolist(), "known_front": known_pareto_front})
+            self.register_additional_config(
+                {
+                    "total_timesteps": total_timesteps,
+                    "ref_point": ref_point.tolist(),
+                    "known_front": known_pareto_front,
+                    "num_eval_weights_for_front": num_eval_weights_for_front,
+                    "num_eval_episodes_for_front": num_eval_episodes_for_front,
+                    "timesteps_per_iter": timesteps_per_iter,
+                    "weight_selection_algo": weight_selection_algo,
+                }
+            )

         max_iter = total_timesteps // timesteps_per_iter
         linear_support = LinearSupport(num_objectives=self.reward_dim, epsilon=0.0 if weight_selection_algo == "ols" else None)
diff --git a/morl_baselines/multi_policy/gpi_pd/gpi_pd_continuous_action.py b/morl_baselines/multi_policy/gpi_pd/gpi_pd_continuous_action.py
index d1193b81..959224e6 100644
--- a/morl_baselines/multi_policy/gpi_pd/gpi_pd_continuous_action.py
+++ b/morl_baselines/multi_policy/gpi_pd/gpi_pd_continuous_action.py
@@ -612,7 +612,18 @@ def train(
             eval_freq (int): Number of timesteps between evaluations during an iteration.
         """
         if self.log:
-            self.register_additional_config({"ref_point": ref_point.tolist(), "known_front": known_pareto_front})
+            self.register_additional_config(
+                {
+                    "total_timesteps": total_timesteps,
+                    "ref_point": ref_point.tolist(),
+                    "known_front": known_pareto_front,
+                    "num_eval_weights_for_front": num_eval_weights_for_front,
+                    "num_eval_episodes_for_front": num_eval_episodes_for_front,
+                    "weight_selection_algo": weight_selection_algo,
+                    "timesteps_per_iter": timesteps_per_iter,
+                    "eval_freq": eval_freq,
+                }
+            )

         max_iter = total_timesteps // timesteps_per_iter
         linear_support = LinearSupport(num_objectives=self.reward_dim, epsilon=0.0 if weight_selection_algo == "ols" else None)
diff --git a/morl_baselines/multi_policy/multi_policy_moqlearning/mp_mo_q_learning.py b/morl_baselines/multi_policy/multi_policy_moqlearning/mp_mo_q_learning.py
index cf43c50d..e444e793 100644
--- a/morl_baselines/multi_policy/multi_policy_moqlearning/mp_mo_q_learning.py
+++ b/morl_baselines/multi_policy/multi_policy_moqlearning/mp_mo_q_learning.py
@@ -178,7 +178,17 @@ def train(
             eval_freq: The frequency of evaluation.
         """
         if self.log:
-            self.register_additional_config({"ref_point": ref_point.tolist(), "known_front": known_pareto_front})
+            self.register_additional_config(
+                {
+                    "total_timesteps": total_timesteps,
+                    "ref_point": ref_point.tolist(),
+                    "known_front": known_pareto_front,
+                    "timesteps_per_iteration": timesteps_per_iteration,
+                    "num_eval_weights_for_front": num_eval_weights_for_front,
+                    "num_eval_episodes_for_front": num_eval_episodes_for_front,
+                    "eval_freq": eval_freq,
+                }
+            )
         num_iterations = int(total_timesteps / timesteps_per_iteration)
         if eval_env is None:
             eval_env = deepcopy(self.env)
diff --git a/morl_baselines/multi_policy/pareto_q_learning/pql.py b/morl_baselines/multi_policy/pareto_q_learning/pql.py
index 98166bd3..a41b5224 100644
--- a/morl_baselines/multi_policy/pareto_q_learning/pql.py
+++ b/morl_baselines/multi_policy/pareto_q_learning/pql.py
@@ -222,7 +222,15 @@ def train(
         if ref_point is None:
             ref_point = self.ref_point
         if self.log:
-            self.register_additional_config({"ref_point": ref_point.tolist(), "known_front": known_pareto_front})
+            self.register_additional_config(
+                {
+                    "total_timesteps": total_timesteps,
+                    "ref_point": ref_point.tolist(),
+                    "known_front": known_pareto_front,
+                    "log_every": log_every,
+                    "action_eval": action_eval,
+                }
+            )

         while self.global_step < total_timesteps:
             state, _ = self.env.reset()
diff --git a/morl_baselines/multi_policy/pcn/pcn.py b/morl_baselines/multi_policy/pcn/pcn.py
index de074ef1..af2d7b72 100644
--- a/morl_baselines/multi_policy/pcn/pcn.py
+++ b/morl_baselines/multi_policy/pcn/pcn.py
@@ -355,7 +355,19 @@ def train(
             num_points_pf: number of points to sample from pareto front for metrics calculation
         """
         if self.log:
-            self.register_additional_config({"ref_point": ref_point.tolist(), "known_front": known_pareto_front})
+            self.register_additional_config(
+                {
+                    "total_timesteps": total_timesteps,
+                    "ref_point": ref_point.tolist(),
+                    "known_front": known_pareto_front,
+                    "num_er_episodes": num_er_episodes,
+                    "num_step_episodes": num_step_episodes,
+                    "num_model_updates": num_model_updates,
+                    "max_return": max_return.tolist(),
+                    "max_buffer_size": max_buffer_size,
+                    "num_points_pf": num_points_pf,
+                }
+            )
         self.global_step = 0
         total_episodes = num_er_episodes
         n_checkpoints = 0
diff --git a/morl_baselines/multi_policy/pgmorl/pgmorl.py b/morl_baselines/multi_policy/pgmorl/pgmorl.py
index 5b6b1864..1ba9addb 100644
--- a/morl_baselines/multi_policy/pgmorl/pgmorl.py
+++ b/morl_baselines/multi_policy/pgmorl/pgmorl.py
@@ -620,7 +620,9 @@ def train(
     ):
         """Trains the agents."""
         if self.log:
-            self.register_additional_config({"ref_point": ref_point.tolist(), "known_front": known_pareto_front})
+            self.register_additional_config(
+                {"total_timesteps": total_timesteps, "ref_point": ref_point.tolist(), "known_front": known_pareto_front}
+            )
         max_iterations = total_timesteps // self.steps_per_iteration // self.num_envs
         iteration = 0
         # Init
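
Note (outside the patch): a minimal usage sketch of what this change means for a caller. Only the train() keyword names come from the hunks above; the environment id, reference-point values, eval_env argument, and the CAPQL constructor call are illustrative assumptions, not taken from the diff.

import mo_gymnasium as mo_gym
import numpy as np

from morl_baselines.multi_policy.capql.capql import CAPQL

# Assumed environment and constructor call, for illustration only.
env = mo_gym.make("mo-hopper-v4")
agent = CAPQL(env, log=True)

# With this patch, the train() settings below (total_timesteps, eval_freq,
# the eval-weight counts, etc.) are all passed to register_additional_config()
# and show up in the logged run config, instead of only ref_point / known_front.
agent.train(
    total_timesteps=100_000,
    eval_env=mo_gym.make("mo-hopper-v4"),  # assumed parameter, not shown in the CAPQL hunk
    ref_point=np.array([-100.0, -100.0, -100.0]),
    known_pareto_front=None,
    num_eval_weights_for_front=100,
    num_eval_episodes_for_front=5,
    eval_freq=10_000,
    reset_num_timesteps=False,
)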