From dbcf5ed25a3cfc7c4e324d889442f028e0d283ce Mon Sep 17 00:00:00 2001
From: Chris Nota
Date: Sun, 25 Feb 2024 18:21:31 -0500
Subject: [PATCH] Add save_freq argument and refactor scripts (#305)

* add save_freq argument and refactor scripts

* run formatter

* fix test
---
 all/experiments/parallel_env_experiment.py    |  9 ++-
 .../parallel_env_experiment_test.py           |  2 +-
 all/experiments/run_experiment.py             |  2 +
 all/experiments/single_env_experiment.py      |  7 ++
 all/scripts/classic.py                        | 63 ----------------
 all/scripts/continuous.py                     | 75 -------------------
 all/scripts/{atari.py => train.py}            | 38 ++++++----
 all/scripts/train_atari.py                    | 18 +++++
 all/scripts/train_classic.py                  | 18 +++++
 all/scripts/train_continuous.py               | 18 +++++
 all/scripts/train_mujoco.py                   | 18 +++++
 ...ent_atari.py => train_multiagent_atari.py} |  7 +-
 all/scripts/train_pybullet.py                 | 18 +++++
 all/scripts/watch_classic.py                  |  6 +-
 all/scripts/watch_continuous.py               | 18 ++---
 all/scripts/watch_mujoco.py                   | 27 +++++++
 all/scripts/watch_multiagent_atari.py         |  2 +-
 all/scripts/watch_pybullet.py                 | 29 +++++++
 setup.py                                      | 12 ++-
 19 files changed, 211 insertions(+), 176 deletions(-)
 delete mode 100644 all/scripts/classic.py
 delete mode 100644 all/scripts/continuous.py
 rename all/scripts/{atari.py => train.py} (58%)
 create mode 100644 all/scripts/train_atari.py
 create mode 100644 all/scripts/train_classic.py
 create mode 100644 all/scripts/train_continuous.py
 create mode 100644 all/scripts/train_mujoco.py
 rename all/scripts/{multiagent_atari.py => train_multiagent_atari.py} (92%)
 create mode 100644 all/scripts/train_pybullet.py
 create mode 100644 all/scripts/watch_mujoco.py
 create mode 100644 all/scripts/watch_pybullet.py

diff --git a/all/experiments/parallel_env_experiment.py b/all/experiments/parallel_env_experiment.py
index cc462974..ff0327cf 100644
--- a/all/experiments/parallel_env_experiment.py
+++ b/all/experiments/parallel_env_experiment.py
@@ -21,6 +21,7 @@ def __init__(
         logdir="runs",
         quiet=False,
         render=False,
+        save_freq=100,
         verbose=True,
         logger="tensorboard",
     ):
@@ -37,6 +38,7 @@
         self._preset = preset
         self._agent = preset.agent(logger=self._logger, train_steps=train_steps)
         self._render = render
+        self._save_freq = save_freq

         # training state
         self._returns = []
@@ -87,9 +89,10 @@ def train(self, frames=np.inf, episodes=np.inf):
             for i in range(num_envs):
                 if dones[i]:
                     self._log_training_episode(returns[i], episode_lengths[i], fps)
+                    self._save_model()
                     returns[i] = 0
                     episode_lengths[i] = 0
-            self._episode += episodes_completed
+                    self._episode += 1

     def test(self, episodes=100):
         test_agent = self._preset.parallel_test_agent()
@@ -144,3 +147,7 @@ def _make_logger(self, logdir, agent_name, env_name, verbose, logger):
         return ExperimentLogger(
             self, agent_name, env_name, verbose=verbose, logdir=logdir
         )
+
+    def _save_model(self):
+        if self._save_freq != float("inf") and self._episode % self._save_freq == 0:
+            self.save()
diff --git a/all/experiments/parallel_env_experiment_test.py b/all/experiments/parallel_env_experiment_test.py
index f65d9e49..f922fe51 100644
--- a/all/experiments/parallel_env_experiment_test.py
+++ b/all/experiments/parallel_env_experiment_test.py
@@ -36,7 +36,7 @@ def test_writes_training_returns_episode(self):
         self.experiment.train(episodes=4)
         np.testing.assert_equal(
             self.experiment._logger.data["eval/returns/episode"]["steps"],
-            np.array([1, 2, 3, 3]),
+            np.array([1, 2, 3, 4]),
         )
         np.testing.assert_equal(
             self.experiment._logger.data["eval/returns/episode"]["values"],
diff --git a/all/experiments/run_experiment.py b/all/experiments/run_experiment.py
index 3ed2021d..a89d3e9c 100644
--- a/all/experiments/run_experiment.py
+++ b/all/experiments/run_experiment.py
@@ -11,6 +11,7 @@ def run_experiment(
     logdir="runs",
     quiet=False,
     render=False,
+    save_freq=100,
    test_episodes=100,
     verbose=True,
     logger="tensorboard",
@@ -32,6 +33,7 @@
         logdir=logdir,
         quiet=quiet,
         render=render,
+        save_freq=save_freq,
         verbose=verbose,
         logger=logger,
     )
diff --git a/all/experiments/single_env_experiment.py b/all/experiments/single_env_experiment.py
index d3ecb0ed..23e666db 100644
--- a/all/experiments/single_env_experiment.py
+++ b/all/experiments/single_env_experiment.py
@@ -19,6 +19,7 @@ def __init__(
         logdir="runs",
         quiet=False,
         render=False,
+        save_freq=100,
         verbose=True,
         logger="tensorboard",
     ):
@@ -33,6 +34,7 @@
         self._render = render
         self._frame = 1
         self._episode = 1
+        self._save_freq = save_freq

         if render:
             self._env.render(mode="human")
@@ -88,6 +90,7 @@ def _run_training_episode(self):

         # log the results
         self._log_training_episode(returns, episode_length, fps)
+        self._save_model()

         # update experiment state
         self._episode += 1
@@ -121,3 +124,7 @@ def _make_logger(self, logdir, agent_name, env_name, verbose, logger):
         return ExperimentLogger(
             self, agent_name, env_name, verbose=verbose, logdir=logdir
         )
+
+    def _save_model(self):
+        if self._save_freq != float("inf") and self._episode % self._save_freq == 0:
+            self.save()
diff --git a/all/scripts/classic.py b/all/scripts/classic.py
deleted file mode 100644
index 59808c84..00000000
--- a/all/scripts/classic.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import argparse
-
-from all.environments import GymEnvironment
-from all.experiments import run_experiment
-from all.presets import classic_control
-
-
-def main():
-    parser = argparse.ArgumentParser(description="Run a classic control benchmark.")
-    parser.add_argument("env", help="Name of the env (e.g. CartPole-v1).")
-    parser.add_argument(
-        "agent", help="Name of the agent (e.g. dqn). See presets for available agents."
-    )
-    parser.add_argument(
-        "--device",
-        default="cuda",
-        help="The name of the device to run the agent on (e.g. cpu, cuda, cuda:0).",
-    )
-    parser.add_argument(
-        "--frames", type=int, default=50000, help="The number of training frames."
-    )
-    parser.add_argument(
-        "--render", action="store_true", default=False, help="Render the environment."
-    )
-    parser.add_argument("--logdir", default="runs", help="The base logging directory.")
-    parser.add_argument(
-        "--logger",
-        default="tensorboard",
-        help="The backend used for tracking experiment metrics.",
-    )
-    parser.add_argument(
-        "--hyperparameters",
-        default=[],
-        nargs="*",
-        help="Custom hyperparameters, in the format hyperparameter1=value1 hyperparameter2=value2 etc.",
-    )
-    args = parser.parse_args()
-
-    env = GymEnvironment(args.env, device=args.device)
-
-    agent_name = args.agent
-    agent = getattr(classic_control, agent_name)
-    agent = agent.device(args.device)
-
-    # parse hyperparameters
-    hyperparameters = {}
-    for hp in args.hyperparameters:
-        key, value = hp.split("=")
-        hyperparameters[key] = type(agent.default_hyperparameters[key])(value)
-    agent = agent.hyperparameters(**hyperparameters)
-
-    run_experiment(
-        agent,
-        env,
-        frames=args.frames,
-        render=args.render,
-        logdir=args.logdir,
-        logger=args.logger,
-    )
-
-
-if __name__ == "__main__":
-    main()
diff --git a/all/scripts/continuous.py b/all/scripts/continuous.py
deleted file mode 100644
index a46b6be3..00000000
--- a/all/scripts/continuous.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# pylint: disable=unused-import
-import argparse
-
-from all.environments import GymEnvironment, PybulletEnvironment
-from all.experiments import run_experiment
-from all.presets import continuous
-
-# see also: PybulletEnvironment.short_names
-ENVS = {
-    "mountaincar": "MountainCarContinuous-v0",
-    "lander": "LunarLanderContinuous-v2",
-}
-
-
-def main():
-    parser = argparse.ArgumentParser(description="Run a continuous actions benchmark.")
-    parser.add_argument("env", help="Name of the env (e.g. 'lander', 'cheetah')")
-    parser.add_argument(
-        "agent", help="Name of the agent (e.g. ddpg). See presets for available agents."
-    )
-    parser.add_argument(
-        "--device",
-        default="cuda",
-        help="The name of the device to run the agent on (e.g. cpu, cuda, cuda:0).",
-    )
-    parser.add_argument(
-        "--frames", type=int, default=2e6, help="The number of training frames."
-    )
-    parser.add_argument(
-        "--render", action="store_true", default=False, help="Render the environment."
-    )
-    parser.add_argument("--logdir", default="runs", help="The base logging directory.")
-    parser.add_argument(
-        "--logger",
-        default="tensorboard",
-        help="The backend used for tracking experiment metrics.",
-    )
-    parser.add_argument(
-        "--hyperparameters",
-        default=[],
-        nargs="*",
-        help="Custom hyperparameters, in the format hyperparameter1=value1 hyperparameter2=value2 etc.",
-    )
-    args = parser.parse_args()
-
-    if args.env in ENVS:
-        env = GymEnvironment(ENVS[args.env], device=args.device)
-    elif "BulletEnv" in args.env or args.env in PybulletEnvironment.short_names:
-        env = PybulletEnvironment(args.env, device=args.device)
-    else:
-        env = GymEnvironment(args.env, device=args.device)
-
-    agent_name = args.agent
-    agent = getattr(continuous, agent_name)
-    agent = agent.device(args.device)
-
-    # parse hyperparameters
-    hyperparameters = {}
-    for hp in args.hyperparameters:
-        key, value = hp.split("=")
-        hyperparameters[key] = type(agent.default_hyperparameters[key])(value)
-    agent = agent.hyperparameters(**hyperparameters)
-
-    run_experiment(
-        agent,
-        env,
-        frames=args.frames,
-        render=args.render,
-        logdir=args.logdir,
-        logger=args.logger,
-    )
-
-
-if __name__ == "__main__":
-    main()
diff --git a/all/scripts/atari.py b/all/scripts/train.py
similarity index 58%
rename from all/scripts/atari.py
rename to all/scripts/train.py
index 804f46d6..d16e44c7 100644
--- a/all/scripts/atari.py
+++ b/all/scripts/train.py
@@ -1,15 +1,21 @@
 import argparse

-from all.environments import AtariEnvironment
 from all.experiments import run_experiment
-from all.presets import atari


-def main():
-    parser = argparse.ArgumentParser(description="Run an Atari benchmark.")
-    parser.add_argument("env", help="Name of the Atari game (e.g. Pong).")
+def train(
+    presets,
+    env_constructor,
+    description="Train an RL agent",
+    env_help="Name of the environment (e.g., 'CartPole-v0')",
+    default_frames=40e6,
+):
+    # parse command line args
+    parser = argparse.ArgumentParser(description=description)
+    parser.add_argument("env", help=env_help)
     parser.add_argument(
-        "agent", help="Name of the agent (e.g. dqn). See presets for available agents."
+        "agent",
+        help="Name of the agent (e.g. 'dqn'). See presets for available agents.",
     )
     parser.add_argument(
         "--device",
@@ -17,7 +23,10 @@
         help="The name of the device to run the agent on (e.g. cpu, cuda, cuda:0).",
     )
     parser.add_argument(
-        "--frames", type=int, default=40e6, help="The number of training frames."
+        "--frames",
+        type=int,
+        default=default_frames,
+        help="The number of training frames.",
     )
     parser.add_argument(
         "--render", action="store_true", default=False, help="Render the environment."
@@ -28,13 +37,18 @@
         default="tensorboard",
         help="The backend used for tracking experiment metrics.",
     )
+    parser.add_argument(
+        "--save_freq", default=100, help="How often to save the model, in episodes."
+    )
     parser.add_argument("--hyperparameters", default=[], nargs="*")
     args = parser.parse_args()

-    env = AtariEnvironment(args.env, device=args.device)
+    # construct the environment
+    env = env_constructor(args.env, device=args.device)

+    # construct the agents
     agent_name = args.agent
-    agent = getattr(atari, agent_name)
+    agent = getattr(presets, agent_name)
     agent = agent.device(args.device)

     # parse hyperparameters
@@ -44,6 +58,7 @@
         hyperparameters[key] = type(agent.default_hyperparameters[key])(value)
     agent = agent.hyperparameters(**hyperparameters)

+    # run the experiment
     run_experiment(
         agent,
         env,
@@ -51,8 +66,5 @@
         render=args.render,
         logdir=args.logdir,
         logger=args.logger,
+        save_freq=args.save_freq,
     )
-
-
-if __name__ == "__main__":
-    main()
diff --git a/all/scripts/train_atari.py b/all/scripts/train_atari.py
new file mode 100644
index 00000000..20c968bb
--- /dev/null
+++ b/all/scripts/train_atari.py
@@ -0,0 +1,18 @@
+from all.environments import AtariEnvironment
+from all.presets import atari
+
+from .train import train
+
+
+def main():
+    train(
+        atari,
+        AtariEnvironment,
+        description="Train an agent on an Atari environment.",
+        env_help="The name of the environment (e.g., 'Pong').",
+        default_frames=40e6,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/all/scripts/train_classic.py b/all/scripts/train_classic.py
new file mode 100644
index 00000000..2fd4e1cd
--- /dev/null
+++ b/all/scripts/train_classic.py
@@ -0,0 +1,18 @@
+from all.environments import GymEnvironment
+from all.presets import classic_control
+
+from .train import train
+
+
+def main():
+    train(
+        classic_control,
+        GymEnvironment,
+        description="Train an agent on a classic control environment.",
+        env_help="The name of the environment (e.g., CartPole-v0).",
+        default_frames=50000,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/all/scripts/train_continuous.py b/all/scripts/train_continuous.py
new file mode 100644
index 00000000..981d730d
--- /dev/null
+++ b/all/scripts/train_continuous.py
@@ -0,0 +1,18 @@
+from all.environments import GymEnvironment
+from all.presets import continuous
+
+from .train import train
+
+
+def main():
+    train(
+        continuous,
+        GymEnvironment,
+        description="Train an agent on a continuous control environment.",
+        env_help="The name of the environment (e.g., LunarLanderContinuous-v2).",
+        default_frames=10e6,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/all/scripts/train_mujoco.py b/all/scripts/train_mujoco.py
new file mode 100644
index 00000000..8ebff69d
--- /dev/null
+++ b/all/scripts/train_mujoco.py
@@ -0,0 +1,18 @@
+from all.environments import MujocoEnvironment
+from all.presets import continuous
+
+from .train import train
+
+
+def main():
+    train(
+        continuous,
+        MujocoEnvironment,
+        description="Train an agent on a MuJoCo environment.",
+        env_help="The name of the environment (e.g., Ant-v4).",
+        default_frames=10e6,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/all/scripts/multiagent_atari.py b/all/scripts/train_multiagent_atari.py
similarity index 92%
rename from all/scripts/multiagent_atari.py
rename to all/scripts/train_multiagent_atari.py
index a76acfe7..255f8c90 100644
--- a/all/scripts/multiagent_atari.py
+++ b/all/scripts/train_multiagent_atari.py
@@ -13,7 +13,7 @@ def __init__(self, state_space, action_space):

 def main():
     parser = argparse.ArgumentParser(description="Run an multiagent Atari benchmark.")
-    parser.add_argument("env", help="Name of the Atari game (e.g. pong_v2).")
+    parser.add_argument("env", help="Name of the Atari game (e.g. pong_v3).")
     parser.add_argument("agents", nargs="*", help="List of agents.")
     parser.add_argument(
         "--device",
@@ -28,6 +28,9 @@
     parser.add_argument(
         "--frames", type=int, default=40e6, help="The number of training frames."
     )
+    parser.add_argument(
+        "--save_freq", default=100, help="How often to save the model, in episodes."
+    )
     parser.add_argument(
         "--render", action="store_true", default=False, help="Render the environment."
     )
@@ -57,9 +60,11 @@
         IndependentMultiagentPreset("Independent", args.device, presets),
         env,
         verbose=False,
+        save_freq=args.save_freq,
         render=args.render,
         logger=args.logger,
     )
+    experiment.save()
     experiment.train(frames=args.frames)
     experiment.save()
     experiment.test(episodes=100)
diff --git a/all/scripts/train_pybullet.py b/all/scripts/train_pybullet.py
new file mode 100644
index 00000000..45d19512
--- /dev/null
+++ b/all/scripts/train_pybullet.py
@@ -0,0 +1,18 @@
+from all.environments import PybulletEnvironment
+from all.presets import continuous
+
+from .train import train
+
+
+def main():
+    train(
+        continuous,
+        PybulletEnvironment,
+        description="Train an agent on a PyBullet environment.",
+        env_help="The name of the environment (e.g., AntBulletEnv-v0).",
+        default_frames=10e6,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/all/scripts/watch_classic.py b/all/scripts/watch_classic.py
index 239f0964..3cf78014 100644
--- a/all/scripts/watch_classic.py
+++ b/all/scripts/watch_classic.py
@@ -5,10 +5,8 @@


 def main():
-    parser = argparse.ArgumentParser(description="Run an Atari benchmark.")
-    parser.add_argument(
-        "env", help="Name of the environment (e.g. RoboschoolHalfCheetah-v1"
-    )
+    parser = argparse.ArgumentParser(description="Watch a classic control agent.")
+    parser.add_argument("env", help="Name of the environment (e.g. CartPole-v0)")
     parser.add_argument("filename", help="File where the model was saved.")
     parser.add_argument(
         "--device",
diff --git a/all/scripts/watch_continuous.py b/all/scripts/watch_continuous.py
index 903f83d7..3abaad9a 100644
--- a/all/scripts/watch_continuous.py
+++ b/all/scripts/watch_continuous.py
@@ -1,15 +1,14 @@
-# pylint: disable=unused-import
 import argparse

-from all.environments import GymEnvironment, PybulletEnvironment
+from all.environments import GymEnvironment
 from all.experiments import load_and_watch

-from .continuous import ENVS
-

 def main():
     parser = argparse.ArgumentParser(description="Watch a continuous agent.")
-    parser.add_argument("env", help="ID of the Environment")
+    parser.add_argument(
+        "env", help="Name of the environment (e.g., LunarLanderContinuous-v2)"
+    )
     parser.add_argument("filename", help="File where the model was saved.")
     parser.add_argument(
         "--device",
@@ -22,14 +21,7 @@
         help="Playback speed",
     )
     args = parser.parse_args()
-
-    if args.env in ENVS:
-        env = GymEnvironment(args.env, device=args.device, render_mode="human")
-    elif "BulletEnv" in args.env or args.env in PybulletEnvironment.short_names:
-        env = PybulletEnvironment(args.env, device=args.device, render_mode="human")
-    else:
-        env = GymEnvironment(args.env, device=args.device, render_mode="human")
-
+    env = GymEnvironment(args.env, device=args.device, render_mode="human")
     load_and_watch(args.filename, env, fps=args.fps)


diff --git a/all/scripts/watch_mujoco.py b/all/scripts/watch_mujoco.py
new file mode 100644
index 00000000..8add986f
--- /dev/null
+++ b/all/scripts/watch_mujoco.py
@@ -0,0 +1,27 @@
+import argparse
+
+from all.environments import MujocoEnvironment
+from all.experiments import load_and_watch
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Watch a mujoco agent.")
+    parser.add_argument("env", help="ID of the Environment")
+    parser.add_argument("filename", help="File where the model was saved.")
+    parser.add_argument(
+        "--device",
+        default="cuda",
+        help="The name of the device to run the agent on (e.g. cpu, cuda, cuda:0)",
+    )
+    parser.add_argument(
+        "--fps",
+        default=120,
+        help="Playback speed",
+    )
+    args = parser.parse_args()
+    env = MujocoEnvironment(args.env, device=args.device, render_mode="human")
+    load_and_watch(args.filename, env, fps=args.fps)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/all/scripts/watch_multiagent_atari.py b/all/scripts/watch_multiagent_atari.py
index 9d16f1c9..b6dacf8d 100644
--- a/all/scripts/watch_multiagent_atari.py
+++ b/all/scripts/watch_multiagent_atari.py
@@ -34,7 +34,7 @@ def watch_episode(env, agent, fps):

 def main():
     parser = argparse.ArgumentParser(description="Watch pretrained multiagent atari")
-    parser.add_argument("env", help="Name of the Atari game (e.g. pong-v1)")
+    parser.add_argument("env", help="Name of the Atari game (e.g. pong_v3)")
     parser.add_argument("filename", help="File where the model was saved.")
     parser.add_argument(
         "--device",
diff --git a/all/scripts/watch_pybullet.py b/all/scripts/watch_pybullet.py
new file mode 100644
index 00000000..1ca3a1c5
--- /dev/null
+++ b/all/scripts/watch_pybullet.py
@@ -0,0 +1,29 @@
+# pylint: disable=unused-import
+import argparse
+
+from all.environments import PybulletEnvironment
+from all.experiments import load_and_watch
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Watch a PyBullet agent.")
+    parser.add_argument("env", help="Name of the environment (e.g., AntBulletEnv-v0)")
+    parser.add_argument("filename", help="File where the model was saved.")
+    parser.add_argument(
+        "--device",
+        default="cuda",
+        help="The name of the device to run the agent on (e.g. cpu, cuda, cuda:0)",
+    )
+    parser.add_argument(
+        "--fps",
+        default=120,
+        help="Playback speed",
+    )
+    args = parser.parse_args()
+    env = PybulletEnvironment(args.env, device=args.device)
+    env.render(mode="human")  # needed for pybullet envs
+    load_and_watch(args.filename, env, fps=args.fps)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/setup.py b/setup.py
index 3f0ee0b9..da27aca9 100644
--- a/setup.py
+++ b/setup.py
@@ -59,15 +59,19 @@
     author_email="cnota@cs.umass.edu",
     entry_points={
         "console_scripts": [
-            "all-atari=all.scripts.atari:main",
-            "all-classic=all.scripts.classic:main",
-            "all-continuous=all.scripts.continuous:main",
-            "all-multiagent-atari=all.scripts.multiagent_atari:main",
             "all-plot=all.scripts.plot:main",
+            "all-atari=all.scripts.train_atari:main",
+            "all-classic=all.scripts.train_classic:main",
+            "all-continuous=all.scripts.train_continuous:main",
+            "all-mujoco=all.scripts.train_mujoco:main",
+            "all-multiagent-atari=all.scripts.train_multiagent_atari:main",
+            "all-pybullet=all.scripts.train_pybullet:main",
             "all-watch-atari=all.scripts.watch_atari:main",
             "all-watch-classic=all.scripts.watch_classic:main",
             "all-watch-continuous=all.scripts.watch_continuous:main",
+            "all-watch-mujoco=all.scripts.watch_mujoco:main",
             "all-watch-multiagent-atari=all.scripts.watch_multiagent_atari:main",
+            "all-watch-pybullet=all.scripts.watch_pybullet:main",
         ],
     },
     install_requires=[
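
A minimal usage sketch of the new argument, assuming only what the patch itself introduces: the preset name "dqn", the environment id "CartPole-v1", and device="cpu" below are illustrative stand-ins taken from the old script help text, while the run_experiment(..., save_freq=...) keyword and the save-every-save_freq-episodes behavior come from the changes above.

    from all.environments import GymEnvironment
    from all.experiments import run_experiment
    from all.presets import classic_control

    # run_experiment now forwards save_freq to the experiment, which calls
    # self.save() every save_freq completed training episodes; passing
    # float("inf") disables the periodic checkpoints.
    env = GymEnvironment("CartPole-v1", device="cpu")
    agent = classic_control.dqn.device("cpu")
    run_experiment(agent, env, frames=50000, save_freq=10)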