diff --git a/pyproject.toml b/pyproject.toml index 9561280..476c930 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ [project.optional-dependencies] training = [ "matplotlib==3.8.2", + "moviepy==1.0.3", "pandas==2.2.0", "pyarrow==15.0.0", "tqdm==4.66.1", diff --git a/scripts/enjoy.py b/scripts/enjoy.py index 7675dd6..064e9c7 100644 --- a/scripts/enjoy.py +++ b/scripts/enjoy.py @@ -105,14 +105,16 @@ def enjoy() -> None: args = parse_args() np.random.seed(args.seed) - env = gym.make(args.env, render_mode="human") if args.record_video: + env = gym.make(args.env, render_mode="rgb_array") env = gym.wrappers.RecordVideo( env, video_folder=args.video_folder, episode_trigger=lambda _: True, disable_logger=True, ) + else: + env = gym.make(args.env, render_mode="human") policy = make_policy(algo=args.algo, trained_agent=args.trained_agent)