fix nan loss issues
bastiscode committed Jul 8, 2024
1 parent 5419d42 commit f154d31
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions python/text_utils/api/trainer.py
@@ -1100,6 +1100,21 @@ def _train_one_epoch(self):
             mean_batch_load.add((end_batch - start_batch) * 1000)
             mean_item_size_ratio.add(max_size / max(1, min_size))

+            def step(
+                batch: data.TrainBatch,
+                inputs: dict[str, Any],
+                labels: torch.Tensor
+            ) -> tuple[torch.Tensor, float]:
+                outputs, loss_dict = self.model(**inputs)
+                loss = self.loss_fn(outputs, labels)
+                loss = loss + sum(loss_dict.values())
+                if self.gradient_accumulation_reduction == "mean":
+                    loss = loss * len(batch) / rank_batch_size
+                if loss.isnan():
+                    loss = torch.zeros_like(loss, requires_grad=True)
+                self.grad_scaler.scale(loss).backward()
+                return outputs.detach(), loss.item()
+
             first_outputs = None
             losses = []
             for i, batch in enumerate(batches):
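
The new step helper centralizes the per-micro-batch forward and backward pass: it adds the auxiliary losses from loss_dict, rescales the loss when gradient accumulation uses "mean" reduction, and, if the result is NaN, swaps it for a zero tensor that still requires grad so the scaled backward() call stays valid and simply contributes no gradients. A minimal standalone sketch of that guard, using a toy model and optimizer rather than the trainer's actual objects:

import torch
import torch.nn as nn

# Toy setup for illustration only; the trainer uses its own model,
# loss function, grad scaler and accumulation logic.
model = nn.Linear(4, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = loss_fn(model(x), y)

if loss.isnan():
    # Replace a NaN loss with a zero that still requires grad:
    # backward() remains legal and adds no gradient, so the bad batch
    # is skipped instead of propagating NaN into the weights.
    loss = torch.zeros_like(loss, requires_grad=True)

loss.backward()
optimizer.step()
optimizer.zero_grad()

Because the helper returns loss.item(), the call sites below now append plain floats to losses instead of tensors.
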
@@ -1119,24 +1134,14 @@ def _train_one_epoch(self):
                 ):
                     if i < len(batches) - 1:
                         with self.model.no_sync():
-                            outputs, loss_dict = self.model(**inputs)
-                            loss = self.loss_fn(outputs, labels)
-                            loss = loss + sum(loss_dict.values())
-                            if self.gradient_accumulation_reduction == "mean":
-                                loss = loss * len(batch) / rank_batch_size
-                            self.grad_scaler.scale(loss).backward()
+                            outputs, loss = step(batch, inputs, labels)
                     else:
                         # synchronize gradients for the last batch
-                        outputs, loss_dict = self.model(**inputs)
-                        loss = self.loss_fn(outputs, labels)
-                        loss = loss + sum(loss_dict.values())
-                        if self.gradient_accumulation_reduction == "mean":
-                            loss = loss * len(batch) / rank_batch_size
-                        self.grad_scaler.scale(loss).backward()
+                        outputs, loss = step(batch, inputs, labels)

-                    losses.append(loss.item())
+                    losses.append(loss)
                     if first_outputs is None:
-                        first_outputs = outputs.detach()
+                        first_outputs = outputs

             dist.barrier()
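
Both branches of the accumulation loop now delegate to step; the only difference is that every micro-batch except the last runs under self.model.no_sync(), so DDP defers gradient synchronization until the final backward. A rough sketch of that accumulation pattern under the same assumptions; the names accumulate, micro_batches and total_items are illustrative, not part of the trainer:

import contextlib
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def accumulate(model, loss_fn, optimizer, micro_batches, total_items):
    # Accumulate gradients over micro-batches; only the last backward
    # triggers DDP gradient synchronization.
    for i, (x, y) in enumerate(micro_batches):
        last = i == len(micro_batches) - 1
        ctx = (
            model.no_sync()
            if isinstance(model, DDP) and not last
            else contextlib.nullcontext()
        )
        with ctx:
            loss = loss_fn(model(x), y)
            loss = loss * len(x) / total_items  # "mean" accumulation reduction
            if loss.isnan():
                loss = torch.zeros_like(loss, requires_grad=True)
            loss.backward()
    optimizer.step()
    optimizer.zero_grad()

model = nn.Linear(4, 1)
batches = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(4)]
accumulate(model, nn.MSELoss(), torch.optim.SGD(model.parameters(), lr=0.1), batches, total_items=8)

Since the NaN guard lives inside the per-micro-batch step, every micro-batch still issues a backward call, so the loop shape stays the same whether or not a batch produced a NaN loss.
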

@@ -1374,6 +1379,8 @@ def _evaluate_and_checkpoint(self):
                 loss = self.loss_fn(outputs, labels)

                 loss = loss + sum(loss_dict.values())
+                if loss.isnan():
+                    loss = torch.zeros_like(loss)
                 if self.gradient_accumulation_reduction == "mean":
                     loss = loss * len(batch) / rank_batch_size

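Note that the guard in the last hunk creates the replacement zero without requires_grad=True. The distinction only matters if backward() is later called on the result; a short illustration of the two variants:

import torch

nan_loss = torch.tensor(float("nan"))

# Variant used on the backward path: still participates in autograd.
guarded = torch.zeros_like(nan_loss, requires_grad=True)
guarded.backward()  # valid, contributes zero gradients

# Variant used where the loss is only aggregated or logged.
logged = torch.zeros_like(nan_loss)
print(logged.item())  # 0.0 instead of nan
# logged.backward() would raise a RuntimeError: the tensor does not require grad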
