From a12a828ede58d289cd8d3789264713f11c263a37 Mon Sep 17 00:00:00 2001 From: Chris Nota Date: Tue, 5 Mar 2024 11:39:38 -0500 Subject: [PATCH] fix plotter and log final summary at end of training (#320) --- all/experiments/parallel_env_experiment.py | 2 ++ all/experiments/plots.py | 2 +- all/experiments/single_env_experiment.py | 2 ++ 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/all/experiments/parallel_env_experiment.py b/all/experiments/parallel_env_experiment.py index 7e38139b..33d975ac 100644 --- a/all/experiments/parallel_env_experiment.py +++ b/all/experiments/parallel_env_experiment.py @@ -92,6 +92,8 @@ def train(self, frames=np.inf, episodes=np.inf): returns[i] = 0 episode_lengths[i] = -1 self._episode += 1 + if len(self._returns100) > 0: + self._logger.add_summary("returns100", self._returns100) def test(self, episodes=100): test_agent = self._preset.parallel_test_agent() diff --git a/all/experiments/plots.py b/all/experiments/plots.py index 400c6266..579b16a6 100644 --- a/all/experiments/plots.py +++ b/all/experiments/plots.py @@ -23,7 +23,7 @@ def load_returns_100_data(runs_dir): def add_data(agent, env, file): if env not in data: data[env] = {} - data[env][agent] = np.genfromtxt(file, delimiter=",").reshape((-1, 3)) + data[env][agent] = np.genfromtxt(file, delimiter=",").reshape((-1, 5)) for agent_dir in os.listdir(runs_dir): agent, env, *_ = agent_dir.split("_") diff --git a/all/experiments/single_env_experiment.py b/all/experiments/single_env_experiment.py index f4ad1cb8..53e152d4 100644 --- a/all/experiments/single_env_experiment.py +++ b/all/experiments/single_env_experiment.py @@ -49,6 +49,8 @@ def episode(self): def train(self, frames=np.inf, episodes=np.inf): while not self._done(frames, episodes): self._run_training_episode() + if len(self._returns100) > 0: + self._logger.add_summary("returns100", self._returns100) def test(self, episodes=100): test_agent = self._preset.test_agent()