From f123cda8f8e7ec1df408a96ec2019fdeea74901a Mon Sep 17 00:00:00 2001 From: Ray Andrew Date: Wed, 16 Oct 2024 04:19:26 +0000 Subject: [PATCH 1/3] fix last step is not executed If we have 730 steps, the DLIO benchmark only executes up to step 729. The bug also persists when the user specifies `total_training_steps`. --- dlio_benchmark/main.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/dlio_benchmark/main.py b/dlio_benchmark/main.py index dd91fa0f..9b3d2d9a 100644 --- a/dlio_benchmark/main.py +++ b/dlio_benchmark/main.py @@ -257,6 +257,12 @@ def _train(self, epoch): loader = self.framework.get_loader(dataset_type=DatasetType.TRAIN) t0 = time() for batch in dlp.iter(loader.next()): + if overall_step > max_steps or overall_step > self.total_training_steps: + if self.args.my_rank == 0: + logging.info(f"{utcnow()} Maximum number of steps reached") + if (block_step != 1 and self.do_checkpoint) or (not self.do_checkpoint): + self.stats.end_block(epoch, block, block_step - 1) + break self.stats.batch_loaded(epoch, overall_step, block, t0) # Log a new block, unless it's the first one which we've already logged before the loop if block_step == 1 and block != 1: @@ -283,13 +289,6 @@ def _train(self, epoch): self.next_checkpoint_step += self.steps_between_checkpoints else: block_step += 1 - - if overall_step >= max_steps or overall_step == self.total_training_steps: - if self.args.my_rank == 0: - logging.info(f"{utcnow()} Maximum number of steps reached") - if (block_step != 1 and self.do_checkpoint) or (not self.do_checkpoint): - self.stats.end_block(epoch, block, block_step - 1) - break overall_step += 1 t0 = time() self.comm.barrier() From d352b5d61634230aa1fb2b35d768fe4b3e84f948 Mon Sep 17 00:00:00 2001 From: Ray Andrew Date: Wed, 16 Oct 2024 15:25:38 +0000 Subject: [PATCH 2/3] fix bug when `total_training_steps` is not specified If `total_training_steps` is not specified, the default will be -1. 
Thus, checking whether it is > 0 is needed before comparing against it. --- dlio_benchmark/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlio_benchmark/main.py b/dlio_benchmark/main.py index 9b3d2d9a..f16e488a 100644 --- a/dlio_benchmark/main.py +++ b/dlio_benchmark/main.py @@ -257,7 +257,7 @@ def _train(self, epoch): loader = self.framework.get_loader(dataset_type=DatasetType.TRAIN) t0 = time() for batch in dlp.iter(loader.next()): - if overall_step > max_steps or overall_step > self.total_training_steps: + if overall_step > max_steps or ((self.total_training_steps > 0) and (overall_step > self.total_training_steps)): if self.args.my_rank == 0: logging.info(f"{utcnow()} Maximum number of steps reached") if (block_step != 1 and self.do_checkpoint) or (not self.do_checkpoint): From 9185195f844c4d448beace8c4cdac25d816134ab Mon Sep 17 00:00:00 2001 From: Ray Andrew Date: Wed, 16 Oct 2024 18:17:19 +0000 Subject: [PATCH 3/3] handle StopIteration raised by DALI pipelines --- dlio_benchmark/data_loader/dali_data_loader.py | 5 ++++- dlio_benchmark/data_loader/native_dali_data_loader.py | 9 ++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/dlio_benchmark/data_loader/dali_data_loader.py b/dlio_benchmark/data_loader/dali_data_loader.py index 1c5f7b86..1ded61d3 100644 --- a/dlio_benchmark/data_loader/dali_data_loader.py +++ b/dlio_benchmark/data_loader/dali_data_loader.py @@ -140,7 +140,10 @@ def next(self): self.read(True) while step < self.num_samples // self.batch_size: for pipe in self.pipelines: - outputs = pipe.share_outputs() + try: + outputs = pipe.share_outputs() + except StopIteration: + return logging.debug(f"{utcnow()} Output batch {step} {len(outputs)}") yield outputs step += 1 diff --git a/dlio_benchmark/data_loader/native_dali_data_loader.py b/dlio_benchmark/data_loader/native_dali_data_loader.py index 3755a23d..39ea3198 100644 --- a/dlio_benchmark/data_loader/native_dali_data_loader.py +++ b/dlio_benchmark/data_loader/native_dali_data_loader.py @@ -59,9 +59,12 @@ 
def next(self): for pipeline in self.pipelines: pipeline.reset() for step in range(num_samples // batch_size): - for batch in self._dataset: - logging.debug(f"{utcnow()} Creating {len(batch)} batches by {self._args.my_rank} rank ") - yield batch + try: + for batch in self._dataset: + logging.debug(f"{utcnow()} Creating {len(batch)} batches by {self._args.my_rank} rank ") + yield batch + except StopIteration: + return self.epoch_number += 1 dlp.update(epoch=self.epoch_number) @dlp.log