Skip to content

Commit

Permalink
Add save_freq argument and refactor scripts (#305)
Browse files Browse the repository at this point in the history
* add save_freq argument and refactor scripts

* run formatter

* fix test
  • Loading branch information
cpnota authored Feb 25, 2024
1 parent 2ccda12 commit dbcf5ed
Show file tree
Hide file tree
Showing 19 changed files with 211 additions and 176 deletions.
9 changes: 8 additions & 1 deletion all/experiments/parallel_env_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(
logdir="runs",
quiet=False,
render=False,
save_freq=100,
verbose=True,
logger="tensorboard",
):
Expand All @@ -37,6 +38,7 @@ def __init__(
self._preset = preset
self._agent = preset.agent(logger=self._logger, train_steps=train_steps)
self._render = render
self._save_freq = save_freq

# training state
self._returns = []
Expand Down Expand Up @@ -87,9 +89,10 @@ def train(self, frames=np.inf, episodes=np.inf):
for i in range(num_envs):
if dones[i]:
self._log_training_episode(returns[i], episode_lengths[i], fps)
self._save_model()
returns[i] = 0
episode_lengths[i] = 0
self._episode += episodes_completed
self._episode += 1

def test(self, episodes=100):
test_agent = self._preset.parallel_test_agent()
Expand Down Expand Up @@ -144,3 +147,7 @@ def _make_logger(self, logdir, agent_name, env_name, verbose, logger):
return ExperimentLogger(
self, agent_name, env_name, verbose=verbose, logdir=logdir
)

def _save_model(self):
if self._save_freq != float("inf") and self._episode % self._save_freq == 0:
self.save()
2 changes: 1 addition & 1 deletion all/experiments/parallel_env_experiment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_writes_training_returns_episode(self):
self.experiment.train(episodes=4)
np.testing.assert_equal(
self.experiment._logger.data["eval/returns/episode"]["steps"],
np.array([1, 2, 3, 3]),
np.array([1, 2, 3, 4]),
)
np.testing.assert_equal(
self.experiment._logger.data["eval/returns/episode"]["values"],
Expand Down
2 changes: 2 additions & 0 deletions all/experiments/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def run_experiment(
logdir="runs",
quiet=False,
render=False,
save_freq=100,
test_episodes=100,
verbose=True,
logger="tensorboard",
Expand All @@ -32,6 +33,7 @@ def run_experiment(
logdir=logdir,
quiet=quiet,
render=render,
save_freq=save_freq,
verbose=verbose,
logger=logger,
)
Expand Down
7 changes: 7 additions & 0 deletions all/experiments/single_env_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
logdir="runs",
quiet=False,
render=False,
save_freq=100,
verbose=True,
logger="tensorboard",
):
Expand All @@ -33,6 +34,7 @@ def __init__(
self._render = render
self._frame = 1
self._episode = 1
self._save_freq = 100

if render:
self._env.render(mode="human")
Expand Down Expand Up @@ -88,6 +90,7 @@ def _run_training_episode(self):

# log the results
self._log_training_episode(returns, episode_length, fps)
self._save_model()

# update experiment state
self._episode += 1
Expand Down Expand Up @@ -121,3 +124,7 @@ def _make_logger(self, logdir, agent_name, env_name, verbose, logger):
return ExperimentLogger(
self, agent_name, env_name, verbose=verbose, logdir=logdir
)

def _save_model(self):
if self._save_freq != float("inf") and self._episode % self._save_freq == 0:
self.save()
63 changes: 0 additions & 63 deletions all/scripts/classic.py

This file was deleted.

75 changes: 0 additions & 75 deletions all/scripts/continuous.py

This file was deleted.

38 changes: 25 additions & 13 deletions all/scripts/atari.py → all/scripts/train.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@
import argparse

from all.environments import AtariEnvironment
from all.experiments import run_experiment
from all.presets import atari


def main():
parser = argparse.ArgumentParser(description="Run an Atari benchmark.")
parser.add_argument("env", help="Name of the Atari game (e.g. Pong).")
def train(
presets,
env_constructor,
description="Train an RL agent",
env_help="Name of the environment (e.g., 'CartPole-v0')",
default_frames=40e6,
):
# parse command line args
parser = argparse.ArgumentParser(description=description)
parser.add_argument("env", help=env_help)
parser.add_argument(
"agent", help="Name of the agent (e.g. dqn). See presets for available agents."
"agent",
help="Name of the agent (e.g. 'dqn'). See presets for available agents.",
)
parser.add_argument(
"--device",
default="cuda",
help="The name of the device to run the agent on (e.g. cpu, cuda, cuda:0).",
)
parser.add_argument(
"--frames", type=int, default=40e6, help="The number of training frames."
"--frames",
type=int,
default=default_frames,
help="The number of training frames.",
)
parser.add_argument(
"--render", action="store_true", default=False, help="Render the environment."
Expand All @@ -28,13 +37,18 @@ def main():
default="tensorboard",
help="The backend used for tracking experiment metrics.",
)
parser.add_argument(
"--save_freq", default=100, help="How often to save the model, in episodes."
)
parser.add_argument("--hyperparameters", default=[], nargs="*")
args = parser.parse_args()

env = AtariEnvironment(args.env, device=args.device)
# construct the environment
env = env_constructor(args.env, device=args.device)

# construct the agents
agent_name = args.agent
agent = getattr(atari, agent_name)
agent = getattr(presets, agent_name)
agent = agent.device(args.device)

# parse hyperparameters
Expand All @@ -44,15 +58,13 @@ def main():
hyperparameters[key] = type(agent.default_hyperparameters[key])(value)
agent = agent.hyperparameters(**hyperparameters)

# run the experiment
run_experiment(
agent,
env,
args.frames,
render=args.render,
logdir=args.logdir,
logger=args.logger,
save_freq=args.save_freq,
)


if __name__ == "__main__":
main()
18 changes: 18 additions & 0 deletions all/scripts/train_atari.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from all.environments import AtariEnvironment
from all.presets import atari

from .train import train


def main():
train(
atari,
AtariEnvironment,
description="Train an agent on an Atari environment.",
env_help="The name of the environment (e.g., 'Pong').",
default_frames=40e6,
)


if __name__ == "__main__":
main()
18 changes: 18 additions & 0 deletions all/scripts/train_classic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from all.environments import GymEnvironment
from all.presets import classic_control

from .train import train


def main():
train(
classic_control,
GymEnvironment,
description="Train an agent on an classic control environment.",
env_help="The name of the environment (e.g., CartPole-v0).",
default_frames=50000,
)


if __name__ == "__main__":
main()
18 changes: 18 additions & 0 deletions all/scripts/train_continuous.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from all.environments import GymEnvironment
from all.presets import continuous

from .train import train


def main():
train(
continuous,
GymEnvironment,
description="Train an agent on a continuous control environment.",
env_help="The name of the environment (e.g., LunarLanderContinuous-v2).",
default_frames=10e6,
)


if __name__ == "__main__":
main()
18 changes: 18 additions & 0 deletions all/scripts/train_mujoco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from all.environments import MujocoEnvironment
from all.presets import continuous

from .train import train


def main():
train(
continuous,
MujocoEnvironment,
description="Train an agent on an Mujoco environment.",
env_help="The name of the environment (e.g., Ant-v4).",
default_frames=10e6,
)


if __name__ == "__main__":
main()
Loading

0 comments on commit dbcf5ed

Please sign in to comment.