diff --git a/kerastuner_tensorboard_logger/logger.py b/kerastuner_tensorboard_logger/logger.py index efd54b0..21fd9d7 100644 --- a/kerastuner_tensorboard_logger/logger.py +++ b/kerastuner_tensorboard_logger/logger.py @@ -7,7 +7,7 @@ def timedelta_to_hms(timedelta: timedelta) -> str: - """convert datetime.timedelta to string like '01h:15m:30s'""" + """(Deprecated) convert datetime.timedelta to string like '01h:15m:30s'""" tot_seconds = int(timedelta.total_seconds()) hours = tot_seconds // 3600 minutes = (tot_seconds % 3600) // 60 @@ -30,7 +30,7 @@ class TensorBoardLogger(Logger): def __init__( self, metrics: Union[str, List[str]] = ["acc"], - logdir: str = "logs/hparam_tuning", + logdir: str = "logs/", overwrite: bool = False, ): self.metrics = [metrics] if isinstance(metrics, str) else metrics @@ -46,18 +46,16 @@ def register_tuner(self, tuner_state): def register_trial(self, trial_id: str, trial_state: Dict[str, Any]): """Informs the logger that a new Trial is starting.""" - self.times[trial_id] = datetime.now() + pass def report_trial_state(self, trial_id: str, trial_state: Dict[str, Any]): """Gives the logger information about trial status.""" - execution_time = timedelta_to_hms(datetime.now() - self.times.pop(trial_id)) - name = f"{execution_time}-{trial_id}" - logdir = os.path.join(self.logdir, name) + logdir = os.path.join(self.logdir, trial_id, "hparams") with tf.summary.create_file_writer(logdir).as_default(): hparams = self.parse_hparams(trial_state) hp_board.hparams( - hparams, trial_id=name + hparams, trial_id=trial_id ) # record the values used in this trial for target_metric, metric in self.parse_metrics(trial_state): diff --git a/scripts/local_test.sh b/scripts/local_test.sh index 709ca4d..c83bea0 100755 --- a/scripts/local_test.sh +++ b/scripts/local_test.sh @@ -3,4 +3,4 @@ set -e pytest --disable-warnings tests/ -tensorboard --logdir tests/logs/hparams \ No newline at end of file +tensorboard --logdir tests/logs \ No newline at end of file diff --git a/tests/test_logger.py b/tests/test_logger.py index fc18251..d7da718 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -12,6 +12,7 @@ def test_timedelta_to_hms(): + """(Deprecated)""" td = timedelta(minutes=10, hours=2, seconds=30, microseconds=111) out = timedelta_to_hms(td) assert out == "2h10m30s" @@ -198,6 +199,35 @@ def test_initialize_manual(): tuner.search(train_data, epochs=3, validation_data=test_data) +def test_search_with_callbacks_manual(): + """test logging with TensorBoardCallbacks + manual test is required. log files for tensorboard, + then, run tensorboard server as: + + ```bash + tensorboard --logdir tests/logs/with-callbacks + ``` + + """ + tuner = Hyperband( + build_model, + objective="val_acc", + max_epochs=3, + directory="tests/logs/with-callbacks/search", + project_name="initialize_manual", + overwrite=True, + logger=TensorBoardLogger( + metrics="val_acc", + logdir="tests/logs/with-callbacks", + overwrite=True, + ), + ) + setup_tb(tuner) + train_data, test_data = make_dataset() + callbacks = [tf.keras.callbacks.TensorBoard(log_dir="tests/logs/with-callbacks")] + tuner.search(train_data, epochs=3, validation_data=test_data, callbacks=callbacks) + + def test_parse(): trained_trial_state = { "trial_id": "bb0649bfdb92155d308f12dca83152e1",