From c6338116fb4af3cf9573d3f68613867076d8e3d3 Mon Sep 17 00:00:00 2001
From: Sebastian Walter
Date: Wed, 26 Jun 2024 23:00:05 +0200
Subject: [PATCH] fix uneven batch nums across ranks

---
 python/text_utils/api/trainer.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/python/text_utils/api/trainer.py b/python/text_utils/api/trainer.py
index a3507a9..fb75fcf 100644
--- a/python/text_utils/api/trainer.py
+++ b/python/text_utils/api/trainer.py
@@ -1036,6 +1036,7 @@ def _train_one_epoch(self):
             self.info.device,
             output_op="sum"
         )
+        min_num_batches = torch.zeros(1, dtype=torch.long, device=self.info.device)
 
         metrics = []
         for name, cfg in self.cfg["train"].get("metrics", {}).items():
@@ -1075,6 +1076,10 @@ def _train_one_epoch(self):
                 batches.append(batch)
 
             end_batch = time.perf_counter()
+            min_num_batches[0] = len(batches)
+            dist.all_reduce(min_num_batches, op=dist.ReduceOp.MIN)
+            batches = batches[:min_num_batches.item()]
+            min_num_batches[0] = 0
 
             if len(batches) == 0:
                 self.logger.info(
@@ -1123,6 +1128,8 @@ def _train_one_epoch(self):
             if first_outputs is None:
                 first_outputs = outputs.detach()
 
+            dist.barrier()
+
            if self.clip_gradient_norm is not None:
                 self.grad_scaler.unscale_(self.optimizer)
                 if isinstance(self.model, FSDP):
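
Note (not part of the patch): the hunks above synchronize the number of batches per rank so that every rank executes the same number of forward/backward steps and no rank stalls in a collective while another has run out of data. Below is a minimal standalone sketch of that idea, assuming an already initialized torch.distributed process group; the function and variable names (sync_batches, local_batches) are illustrative and do not appear in trainer.py.

    import torch
    import torch.distributed as dist


    def sync_batches(local_batches: list, device: torch.device) -> list:
        """Truncate the local batch list to the minimum count across all ranks.

        Illustrative helper, not from the patched trainer. Every rank must call
        this collectively, otherwise the all_reduce will deadlock.
        """
        count = torch.tensor([len(local_batches)], dtype=torch.long, device=device)
        # Reduce with MIN so all ranks agree on the smallest local batch count.
        dist.all_reduce(count, op=dist.ReduceOp.MIN)
        # Drop the surplus batches; every rank now steps the same number of times.
        return local_batches[: int(count.item())]

In the patch itself the same reduction is done in place on a preallocated one-element tensor (min_num_batches), and a dist.barrier() after the backward pass keeps the ranks aligned before gradient clipping and the optimizer step.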