Skip to content

Commit

Permalink
Update tensorflow logging
Browse files Browse the repository at this point in the history
  • Loading branch information
msosav committed Jan 6, 2025
1 parent 08c4529 commit 390530a
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,28 @@ def PreprocessEnv(config: dict) -> VecFrameStack:

class CheckpointAndLoggingCallback(BaseCallback):
def __init__(self, check_freq, save_path, verbose=0):
super(CheckpointAndLoggingCallback, self).__init__(verbose)
self.chek_freq = check_freq
super().__init__(verbose)
self.check_freq = check_freq
self.save_path = save_path
self.episode_rewards = []
self.episode_lengths = []

def _init_callback(self) -> None:
if self.save_path is not None:
os.makedirs(self.save_path, exist_ok=True)

def _on_step(self):
if self.n_calls % self.chek_freq == 0:
model_path = os.path.join(self.save_path, f"best_model_{self.n_calls}.zip")
# Log episode info
if self.locals.get('done'):
self.logger.record('game/episode_reward',
self.locals.get('rewards')[0])
self.logger.record('game/episode_length', self.n_calls)
self.logger.record('game/current_health',
self.training_env.get_attr('pyboy')[0].memory[0xDB5A])

# Save model checkpoint
if self.n_calls % self.check_freq == 0:
model_path = f"{self.save_path}/best_model_{self.n_calls}.zip"
self.model.save(model_path)

return True

0 comments on commit 390530a

Please sign in to comment.