diff --git a/dlio_benchmark/data_loader/dali_data_loader.py b/dlio_benchmark/data_loader/dali_data_loader.py
index 1c5f7b86..1ded61d3 100644
--- a/dlio_benchmark/data_loader/dali_data_loader.py
+++ b/dlio_benchmark/data_loader/dali_data_loader.py
@@ -140,7 +140,10 @@ def next(self):
         self.read(True)
         while step < self.num_samples // self.batch_size:
             for pipe in self.pipelines:
-                outputs = pipe.share_outputs()
+                try:
+                    outputs = pipe.share_outputs()
+                except StopIteration:
+                    return
                 logging.debug(f"{utcnow()} Output batch {step} {len(outputs)}")
                 yield outputs
                 step += 1
diff --git a/dlio_benchmark/data_loader/native_dali_data_loader.py b/dlio_benchmark/data_loader/native_dali_data_loader.py
index 3755a23d..39ea3198 100644
--- a/dlio_benchmark/data_loader/native_dali_data_loader.py
+++ b/dlio_benchmark/data_loader/native_dali_data_loader.py
@@ -59,9 +59,12 @@ def next(self):
         for pipeline in self.pipelines:
             pipeline.reset()
         for step in range(num_samples // batch_size):
-            for batch in self._dataset:
-                logging.debug(f"{utcnow()} Creating {len(batch)} batches by {self._args.my_rank} rank ")
-                yield batch
+            try:
+                for batch in self._dataset:
+                    logging.debug(f"{utcnow()} Creating {len(batch)} batches by {self._args.my_rank} rank ")
+                    yield batch
+            except StopIteration:
+                return
         self.epoch_number += 1
         dlp.update(epoch=self.epoch_number)
     @dlp.log
diff --git a/dlio_benchmark/main.py b/dlio_benchmark/main.py
index 0580d7e4..981f9c51 100644
--- a/dlio_benchmark/main.py
+++ b/dlio_benchmark/main.py
@@ -257,6 +257,12 @@ def _train(self, epoch):
         loader = self.framework.get_loader(dataset_type=DatasetType.TRAIN)
         t0 = time()
         for batch in dlp.iter(loader.next()):
+            if overall_step > max_steps or ((self.total_training_steps > 0) and (overall_step > self.total_training_steps)):
+                if self.args.my_rank == 0:
+                    logging.info(f"{utcnow()} Maximum number of steps reached")
+                if (block_step != 1 and self.do_checkpoint) or (not self.do_checkpoint):
+                    self.stats.end_block(epoch, block, block_step - 1)
+                break
             self.stats.batch_loaded(epoch, overall_step, block, t0)
             # Log a new block, unless it's the first one which we've already logged before the loop
             if block_step == 1 and block != 1:
@@ -283,13 +289,6 @@
                 self.next_checkpoint_step += self.steps_between_checkpoints
             else:
                 block_step += 1
-
-            if overall_step >= max_steps or overall_step == self.total_training_steps:
-                if self.args.my_rank == 0:
-                    logging.info(f"{utcnow()} Maximum number of steps reached")
-                if (block_step != 1 and self.do_checkpoint) or (not self.do_checkpoint):
-                    self.stats.end_block(epoch, block, block_step - 1)
-                break
             overall_step += 1
             t0 = time()
         self.comm.barrier()
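Review note, not part of the patch: the `try`/`except StopIteration` followed by `return` that both loaders gain here is the PEP 479-safe way to end a generator when the underlying source is exhausted. Since Python 3.7, a `StopIteration` that escapes a generator body is re-raised as `RuntimeError`, so letting `pipe.share_outputs()` raise through `next()` would crash the training loop rather than end the epoch cleanly. A minimal, self-contained sketch of the two behaviors; `drained_source` is a hypothetical stand-in for `pipe.share_outputs()` on an exhausted pipeline:

```python
# Illustration only: drained_source plays the role of an exhausted
# iterator call such as pipe.share_outputs().

def drained_source():
    raise StopIteration  # what an exhausted iterator signals

def leaky_gen():
    # StopIteration escapes the generator body; under PEP 479
    # (default since Python 3.7) the caller sees RuntimeError.
    yield drained_source()

def safe_gen():
    # The pattern the patch applies: convert StopIteration into a
    # plain return, which ends the generator cleanly.
    try:
        out = drained_source()
    except StopIteration:
        return
    yield out

try:
    list(leaky_gen())
except RuntimeError as err:
    print(f"leaky_gen: {err}")        # generator raised StopIteration

print(f"safe_gen: {list(safe_gen())}")  # safe_gen: []
```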