diff --git a/examples/atari/reproduction/dqn/train_dqn.py b/examples/atari/reproduction/dqn/train_dqn.py
index 72c210ad5..6402201bb 100644
--- a/examples/atari/reproduction/dqn/train_dqn.py
+++ b/examples/atari/reproduction/dqn/train_dqn.py
@@ -31,6 +31,12 @@ def main():
             " If it does not exist, it will be created."
         ),
     )
+    parser.add_argument(
+        "--exp-id",
+        type=str,
+        default=None,
+        help="Experiment ID. If None, commit hash or timestamp is used.",
+    )
     parser.add_argument("--seed", type=int, default=0, help="Random seed [0, 2 ** 31)")
     parser.add_argument(
         "--gpu", type=int, default=0, help="GPU to use, set to -1 if no GPU."
@@ -73,9 +79,22 @@ def main():
         default=5 * 10 ** 4,
         help="Minimum replay buffer size before " + "performing gradient updates.",
     )
+    parser.add_argument(
+        "--save-snapshot",
+        action="store_true",
+        default=False,
+        help="Take a resumable snapshot at every checkpoint",
+    )
+    parser.add_argument(
+        "--load-snapshot",
+        action="store_true",
+        default=False,
+        help="Load the latest snapshot if one exists",
+    )
     parser.add_argument("--eval-n-steps", type=int, default=125000)
     parser.add_argument("--eval-interval", type=int, default=250000)
     parser.add_argument("--n-best-episodes", type=int, default=30)
+    parser.add_argument("--checkpoint-freq", type=int, default=2000000)
     args = parser.parse_args()
 
     import logging
@@ -89,7 +108,9 @@ def main():
     train_seed = args.seed
     test_seed = 2 ** 31 - 1 - args.seed
 
-    args.outdir = experiments.prepare_output_dir(args, args.outdir)
+    args.outdir = experiments.prepare_output_dir(
+        args, args.outdir, exp_id=args.exp_id, make_backup=args.exp_id is None
+    )
     print("Output files are saved in {}".format(args.outdir))
 
     def make_env(test):
@@ -162,6 +183,17 @@ def phi(x):
         phi=phi,
     )
 
+    # load snapshot
+    step_offset, episode_offset = 0, 0
+    max_score = None
+    if args.load_snapshot:
+        snapshot_dirname = experiments.latest_snapshot_dir(args.outdir)
+        if snapshot_dirname:
+            print(f"load snapshot from {snapshot_dirname}")
+            step_offset, episode_offset, max_score = experiments.load_snapshot(
+                agent, snapshot_dirname
+            )
+
     if args.load or args.load_pretrained:
         # either load or load_pretrained must be false
         assert not args.load or not args.load_pretrained
@@ -194,6 +226,11 @@ def phi(x):
             eval_n_steps=args.eval_n_steps,
             eval_n_episodes=None,
             eval_interval=args.eval_interval,
+            step_offset=step_offset,
+            episode_offset=episode_offset,
+            max_score=max_score,
+            checkpoint_freq=args.checkpoint_freq,
+            take_resumable_snapshot=args.save_snapshot,
             outdir=args.outdir,
             save_best_so_far_agent=True,
             eval_env=eval_env,
diff --git a/examples_tests/atari/reproduction/test_dqn.sh b/examples_tests/atari/reproduction/test_dqn.sh
index faa0e9368..3e99633ac 100644
--- a/examples_tests/atari/reproduction/test_dqn.sh
+++ b/examples_tests/atari/reproduction/test_dqn.sh
@@ -10,3 +10,11 @@ gpu="$1"
 python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --steps 100 --replay-start-size 50 --outdir $outdir/atari/reproduction/dqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu
 model=$(find $outdir/atari/reproduction/dqn -name "*_finish")
 python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --demo --load $model --outdir $outdir/temp --eval-n-steps 200 --gpu $gpu
+
+# snapshot without eval
+python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --steps 100 --replay-start-size 50 --outdir $outdir/atari/reproduction/dqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu --exp-id 0 --save-snapshot --checkpoint-freq 45
+python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --steps 100 --replay-start-size 50 --outdir $outdir/atari/reproduction/dqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu --exp-id 0 --load-snapshot
+
+# snapshot after eval
+python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --steps 4600 --replay-start-size 50 --outdir $outdir/atari/reproduction/dqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu --exp-id 1 --save-snapshot --checkpoint-freq 4000
+python examples/atari/reproduction/dqn/train_dqn.py --env PongNoFrameskip-v4 --steps 4700 --replay-start-size 50 --outdir $outdir/atari/reproduction/dqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu --exp-id 1 --load-snapshot
diff --git a/pfrl/experiments/__init__.py b/pfrl/experiments/__init__.py
index e4e79e49d..20f307aaf 100644
--- a/pfrl/experiments/__init__.py
+++ b/pfrl/experiments/__init__.py
@@ -7,7 +7,11 @@
 from pfrl.experiments.prepare_output_dir import is_under_git_control  # NOQA
 from pfrl.experiments.prepare_output_dir import prepare_output_dir  # NOQA
 from pfrl.experiments.train_agent import train_agent  # NOQA
-from pfrl.experiments.train_agent import train_agent_with_evaluation  # NOQA
+from pfrl.experiments.train_agent import (  # NOQA
+    latest_snapshot_dir,
+    load_snapshot,
+    train_agent_with_evaluation,
+)
 from pfrl.experiments.train_agent_async import train_agent_async  # NOQA
 from pfrl.experiments.train_agent_batch import train_agent_batch  # NOQA
 from pfrl.experiments.train_agent_batch import train_agent_batch_with_evaluation  # NOQA
diff --git a/pfrl/experiments/evaluator.py b/pfrl/experiments/evaluator.py
index 75691784c..12664a211 100644
--- a/pfrl/experiments/evaluator.py
+++ b/pfrl/experiments/evaluator.py
@@ -384,7 +384,10 @@ def write_header(outdir, agent, env):
         "max",  # maximum value of returns of evaluation runs
         "min",  # minimum value of returns of evaluation runs
     )
-    with open(os.path.join(outdir, "scores.txt"), "w") as f:
+    fp = os.path.join(outdir, "scores.txt")
+    if os.path.exists(fp) and os.stat(fp).st_size > 0:
+        return
+    with open(fp, "w") as f:
         custom_columns = tuple(t[0] for t in agent.get_statistics())
         env_get_stats = getattr(env, "get_statistics", lambda: [])
         assert callable(env_get_stats)
diff --git a/pfrl/experiments/train_agent.py b/pfrl/experiments/train_agent.py
index 210c7ed24..e89031413 100644
--- a/pfrl/experiments/train_agent.py
+++ b/pfrl/experiments/train_agent.py
@@ -1,5 +1,8 @@
+import csv
 import logging
 import os
+import shutil
+import time
 
 from pfrl.experiments.evaluator import Evaluator, save_agent
 from pfrl.utils.ask_yes_no import ask_yes_no
@@ -21,14 +24,96 @@ def ask_and_save_agent_replay_buffer(agent, t, outdir, suffix=""):
         save_agent_replay_buffer(agent, t, outdir, suffix=suffix)
 
 
+def snapshot(
+    agent,
+    t,
+    episode_idx,
+    outdir,
+    suffix="_snapshot",
+    logger=None,
+    delete_old=True,
+):
+    start_time = time.time()
+    tmp_suffix = f"{suffix}_"
+    tmp_dirname = os.path.join(outdir, f"{t}{tmp_suffix}")  # used until files are saved
+    agent.save(tmp_dirname)
+    if hasattr(agent, "replay_buffer"):
+        agent.replay_buffer.save(os.path.join(tmp_dirname, "replay.pkl"))
+    if os.path.exists(os.path.join(outdir, "scores.txt")):
+        shutil.copyfile(
+            os.path.join(outdir, "scores.txt"), os.path.join(tmp_dirname, "scores.txt")
+        )
+
+    history_path = os.path.join(outdir, "snapshot_history.txt")
+    if not os.path.exists(history_path):  # write header
+        with open(history_path, "a") as f:
+            csv.writer(f, delimiter="\t").writerow(["step", "episode", "snapshot_time"])
+    with open(history_path, "a") as f:
+        csv.writer(f, delimiter="\t").writerow(
+            [t, episode_idx, time.time() - start_time]
+        )
+    shutil.copyfile(history_path, os.path.join(tmp_dirname, "snapshot_history.txt"))
+
+    real_dirname = os.path.join(outdir, f"{t}{suffix}")
+    os.rename(tmp_dirname, real_dirname)
+    if logger:
+        logger.info(f"Saved the snapshot to {real_dirname}")
+    if delete_old:
+        for old_dir in filter(
+            lambda s: s.endswith(suffix) or s.endswith(tmp_suffix), os.listdir(outdir)
+        ):
+            if old_dir != f"{t}{suffix}":
+                shutil.rmtree(os.path.join(outdir, old_dir))
+
+
+def load_snapshot(agent, dirname, logger=None):
+    agent.load(dirname)
+    if hasattr(agent, "replay_buffer"):
+        agent.replay_buffer.load(os.path.join(dirname, "replay.pkl"))
+    if logger:
+        logger.info(f"Loaded the snapshot from {dirname}")
+    with open(os.path.join(dirname, "snapshot_history.txt")) as f:
+        step, episode = map(int, f.readlines()[-1].split()[:2])
+    max_score = None
+    if os.path.exists(os.path.join(dirname, "scores.txt")):
+        with open(os.path.join(dirname, "scores.txt")) as f:
+            lines = f.readlines()
+            if len(lines) > 1:
+                max_score = float(lines[-1].split()[3])  # mean
+    shutil.copyfile(
+        os.path.join(dirname, "snapshot_history.txt"),
+        os.path.join(dirname, "..", "snapshot_history.txt"),
+    )
+    shutil.copyfile(
+        os.path.join(dirname, "scores.txt"),
+        os.path.join(dirname, "..", "scores.txt"),
+    )
+    return step, episode, max_score
+
+
+def latest_snapshot_dir(search_dir, suffix="_snapshot"):
+    """
+    Return the latest snapshot directory, or None if no snapshot exists.
+    """
+    candidates = list(filter(lambda s: s.endswith(suffix), os.listdir(search_dir)))
+    if len(candidates) == 0:
+        return None
+    return os.path.join(
+        search_dir, max(candidates, key=lambda name: int(name.split("_")[0]))
+    )
+
+
 def train_agent(
     agent,
     env,
     steps,
     outdir,
     checkpoint_freq=None,
+    take_resumable_snapshot=False,
     max_episode_len=None,
     step_offset=0,
+    episode_offset=0,
+    max_score=None,
     evaluator=None,
     successful_score=None,
     step_hooks=(),
@@ -38,8 +123,12 @@ def train_agent(
 
     logger = logger or logging.getLogger(__name__)
 
+    # restore max_score
+    if evaluator and max_score:
+        evaluator.max_score = max_score
+
     episode_r = 0
-    episode_idx = 0
+    episode_idx = episode_offset
 
     # o_0, r_0
     obs = env.reset()
@@ -100,7 +189,10 @@ def train_agent(
                 episode_len = 0
                 obs = env.reset()
             if checkpoint_freq and t % checkpoint_freq == 0:
-                save_agent(agent, t, outdir, logger, suffix="_checkpoint")
+                if take_resumable_snapshot:
+                    snapshot(agent, t, episode_idx, outdir, logger=logger)
+                else:
+                    save_agent(agent, t, outdir, logger, suffix="_checkpoint")
 
     except (Exception, KeyboardInterrupt):
         # Save the current model before being killed
@@ -122,9 +214,12 @@ def train_agent_with_evaluation(
     eval_interval,
     outdir,
     checkpoint_freq=None,
+    take_resumable_snapshot=False,
    train_max_episode_len=None,
     step_offset=0,
+    episode_offset=0,
     eval_max_episode_len=None,
+    max_score=None,
     eval_env=None,
     successful_score=None,
     step_hooks=(),
@@ -144,11 +239,18 @@ def train_agent_with_evaluation(
         eval_n_episodes (int): Number of episodes at each evaluation phase.
         eval_interval (int): Interval of evaluation.
         outdir (str): Path to the directory to output data.
-        checkpoint_freq (int): frequency at which agents are stored.
+        checkpoint_freq (int): frequency in steps at which agents are stored.
+        take_resumable_snapshot (bool): If True, a resumable snapshot is saved
+            at each checkpoint. Note that the snapshot currently does not cover
+            agent statistics (e.g., for DQN, average_q, average_loss,
+            cumulative_steps, and n_updates), so those values in "scores.txt"
+            may be incorrect after resuming from a snapshot.
         train_max_episode_len (int): Maximum episode length during training.
         step_offset (int): Time step from which training starts.
+        episode_offset (int): Episode index from which training starts.
         eval_max_episode_len (int or None): Maximum episode length of
             evaluation runs. If None, train_max_episode_len is used instead.
+        max_score (float): Current maximum score, used to restore the evaluator.
         eval_env: Environment used for evaluation.
         successful_score (float): Finish training if the mean score is greater
             than or equal to this value if not None
@@ -211,8 +313,11 @@ def train_agent_with_evaluation(
         steps,
         outdir,
         checkpoint_freq=checkpoint_freq,
+        take_resumable_snapshot=take_resumable_snapshot,
         max_episode_len=train_max_episode_len,
         step_offset=step_offset,
+        episode_offset=episode_offset,
+        max_score=max_score,
         evaluator=evaluator,
         successful_score=successful_score,
         step_hooks=step_hooks,
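
For reference, a resumed run built on the helpers this patch adds would look roughly like the sketch below; it mirrors the train_dqn.py changes above. make_agent() and make_env() are hypothetical placeholders for the agent/environment construction, the fixed output directory stands in for the one produced by prepare_output_dir with a fixed --exp-id, and the hyperparameter values are illustrative only.

from pfrl import experiments

# Hypothetical helpers standing in for the agent/env setup in train_dqn.py.
agent = make_agent()
train_env, eval_env = make_env(test=False), make_env(test=True)

# The original run and the resumed run must share this directory, e.g. by
# passing the same --exp-id so prepare_output_dir reuses it.
outdir = "results/0"

# Pick up the newest "<step>_snapshot" directory, if any, and restore the
# agent, replay buffer, step/episode counters, and best mean score from it.
step_offset, episode_offset, max_score = 0, 0, None
snapshot_dir = experiments.latest_snapshot_dir(outdir)
if snapshot_dir:
    step_offset, episode_offset, max_score = experiments.load_snapshot(
        agent, snapshot_dir
    )

experiments.train_agent_with_evaluation(
    agent=agent,
    env=train_env,
    steps=5 * 10 ** 7,
    eval_n_steps=125000,
    eval_n_episodes=None,
    eval_interval=250000,
    outdir=outdir,
    eval_env=eval_env,
    checkpoint_freq=2 * 10 ** 6,
    take_resumable_snapshot=True,  # write snapshot() instead of save_agent() at checkpoints
    step_offset=step_offset,
    episode_offset=episode_offset,
    max_score=max_score,
)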