Skip to content

Commit

Permalink
Clean ModelEnv termination function
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasAlegre committed Dec 4, 2024
1 parent 57d8fee commit f3f9670
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions morl_baselines/common/model_based/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,18 @@ def __init__(self, model, env_id=None, rew_dim=1):
"""
self.model = model
self.rew_dim = rew_dim
if env_id == "Hopper-v2" or env_id == "Hopper-v4" or env_id == "mo-hopper-v4" or env_id == "mo-hopper-2d-v4":
if "hopper" in env_id:
self.termination_func = termination_fn_hopper
elif env_id == "HalfCheetah-v2" or env_id == "mo-halfcheetah-v4":
elif "halfcheetah" in env_id:
self.termination_func = termination_fn_false
elif env_id == "LunarLanderContinuous-v2" or env_id.startswith("mo-lunar-lander"):
elif "lunar-lander" in env_id:
self.termination_func = termination_fn_false
elif env_id == "ReacherMultiTask-v0" or env_id.startswith("mo-reacher-v"):
elif "mo-reacher" in env_id:
self.termination_func = termination_fn_false
elif env_id == "MountainCarContinuous-v0" or env_id.startswith("mo-mountaincar"):
elif "mountaincar" in env_id:
self.termination_func = termination_fn_mountaincar
elif env_id == "minecart-v0":
elif "minecart" in env_id:
self.termination_func = termination_fn_minecart
elif env_id == "SEIRsingle-v0":
self.termination_func = termination_fn_false
elif env_id == "mo-highway-fast-v0" or env_id == "mo-highway-v0":
self.termination_func = termination_fn_false
elif env_id == "deep-sea-treasure-v0":
Expand Down

0 comments on commit f3f9670

Please sign in to comment.