From a12a828ede58d289cd8d3789264713f11c263a37 Mon Sep 17 00:00:00 2001
From: Chris Nota <cpnota@gmail.com>
Date: Tue, 5 Mar 2024 11:39:38 -0500
Subject: [PATCH] fix plotter and log final summary at end of training (#320)

---
 all/experiments/parallel_env_experiment.py | 2 ++
 all/experiments/plots.py                   | 2 +-
 all/experiments/single_env_experiment.py   | 2 ++
 3 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/all/experiments/parallel_env_experiment.py b/all/experiments/parallel_env_experiment.py
index 7e38139b..33d975ac 100644
--- a/all/experiments/parallel_env_experiment.py
+++ b/all/experiments/parallel_env_experiment.py
@@ -92,6 +92,8 @@ def train(self, frames=np.inf, episodes=np.inf):
                         returns[i] = 0
                         episode_lengths[i] = -1
                         self._episode += 1
+        if len(self._returns100) > 0:
+            self._logger.add_summary("returns100", self._returns100)
 
     def test(self, episodes=100):
         test_agent = self._preset.parallel_test_agent()
diff --git a/all/experiments/plots.py b/all/experiments/plots.py
index 400c6266..579b16a6 100644
--- a/all/experiments/plots.py
+++ b/all/experiments/plots.py
@@ -23,7 +23,7 @@ def load_returns_100_data(runs_dir):
     def add_data(agent, env, file):
         if env not in data:
             data[env] = {}
-        data[env][agent] = np.genfromtxt(file, delimiter=",").reshape((-1, 3))
+        data[env][agent] = np.genfromtxt(file, delimiter=",").reshape((-1, 5))
 
     for agent_dir in os.listdir(runs_dir):
         agent, env, *_ = agent_dir.split("_")
diff --git a/all/experiments/single_env_experiment.py b/all/experiments/single_env_experiment.py
index f4ad1cb8..53e152d4 100644
--- a/all/experiments/single_env_experiment.py
+++ b/all/experiments/single_env_experiment.py
@@ -49,6 +49,8 @@ def episode(self):
     def train(self, frames=np.inf, episodes=np.inf):
         while not self._done(frames, episodes):
             self._run_training_episode()
+        if len(self._returns100) > 0:
+            self._logger.add_summary("returns100", self._returns100)
 
     def test(self, episodes=100):
         test_agent = self._preset.test_agent()