From 3f8cb0bd1851e60abb0784ef56433c546e6fb8c4 Mon Sep 17 00:00:00 2001 From: Felix Hieber Date: Sat, 19 May 2018 18:33:22 +0200 Subject: [PATCH] Fix logic with training resumption (#404) * Fix logic with training resumption * fix --- CHANGELOG.md | 7 ++++++- sockeye/data_io.py | 3 --- sockeye/training.py | 3 --- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 76aeedb6b..207ab3e7f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. ## [1.18.13] + +### Fixed +- Fixed two bugs with training resumption: + 1. removed overly strict assertion in the data iterator for model states before the first checkpoint. + 2. removed deletion of Tensorboard log directory. + ### Added - Added support for config files. Command line parameters have precedence over the values read from the config file. Minimal working example: @@ -22,7 +28,6 @@ Each version section may have have subsections for: _Added_, _Changed_, _Removed validation_source: valid.source.txt validation_target: valid.target.txt ``` - ### Changed The full set of arguments is serialized to `out/args.yaml` at the beginning of training (before json was used). diff --git a/sockeye/data_io.py b/sockeye/data_io.py index d3fad5de9..86b42725c 100644 --- a/sockeye/data_io.py +++ b/sockeye/data_io.py @@ -1537,9 +1537,6 @@ def load_state(self, fname: str): inverse_data_permutations = np.load(fp) data_permutations = np.load(fp) - # Because of how checkpointing is done (pre-fetching the next batch in - # each iteration), curr_idx should always be >= 1 - assert self.curr_batch_index >= 1 # Right after loading the iterator state, next() should be called self.curr_batch_index -= 1 diff --git a/sockeye/training.py b/sockeye/training.py index 67a26d98e..a8f73f087 100644 --- a/sockeye/training.py +++ b/sockeye/training.py @@ -1017,9 +1017,6 @@ def __init__(self, try: import mxboard logger.info("Logging training events for Tensorboard at '%s'", self.logdir) - if os.path.exists(self.logdir): - logger.info("Deleting existing Tensorboard log directory '%s'", self.logdir) - shutil.rmtree(self.logdir) self.sw = mxboard.SummaryWriter(logdir=self.logdir, flush_secs=60, verbose=False) except ImportError: logger.info("mxboard not found. Consider 'pip install mxboard' to log events to Tensorboard.")