fix validation loss computation
bastiscode committed Jul 8, 2024
1 parent 57a5c01 commit 5419d42
Showing 1 changed file with 57 additions and 21 deletions.

python/text_utils/api/trainer.py
@@ -1129,7 +1129,7 @@ def _train_one_epoch(self):
  # synchronize gradients for the last batch
  outputs, loss_dict = self.model(**inputs)
  loss = self.loss_fn(outputs, labels)
- loss = (loss + sum(loss_dict.values()))
+ loss = loss + sum(loss_dict.values())
  if self.gradient_accumulation_reduction == "mean":
      loss = loss * len(batch) / rank_batch_size
  self.grad_scaler.scale(loss).backward()
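A note on the "mean" reduction used above: each micro-batch's loss is scaled by its share of the full batch, so the gradients accumulated over several backward calls match a single backward over the whole batch. A minimal, self-contained sketch of that scheme (all names here are illustrative, not from text_utils):

import torch

# illustrative setup: two micro-batches of uneven size, as under
# gradient accumulation
model = torch.nn.Linear(8, 1)
loss_fn = torch.nn.MSELoss()
micro_batches = [torch.randn(4, 8), torch.randn(2, 8)]
rank_batch_size = sum(len(b) for b in micro_batches)  # 6 samples

model.zero_grad()
for batch in micro_batches:
    loss = loss_fn(model(batch), torch.zeros(len(batch), 1))
    # weight by the micro-batch's share of the full batch so the
    # accumulated gradients equal one backward over all 6 samples
    loss = loss * len(batch) / rank_batch_size
    loss.backward()
# an optimizer step (and grad scaler, in the trainer above) would follow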
@@ -1335,29 +1335,65 @@ def _evaluate_and_checkpoint(self):
  metrics.append(metric)

  start = time.perf_counter()
- for batch_num, batch in enumerate(self.val_loader):
-     inputs, labels = self._prepare_batch(batch)
-
-     with torch.autocast(
-         "cuda",
-         dtype=self.mixed_precision,
-         enabled=self.mixed_precision is not None
-     ), torch.no_grad():
-         outputs, loss_dict = self.model(**inputs)
-         loss = self.loss_fn(outputs, labels)
-
-     mean_loss.add(loss.item())
-     dist.barrier()
-
-     if batch_num == 0 and self.info.is_main_process:
-         items = batch.items()
-         for metric in metrics:
-             metric.set_values(items, outputs)
-             metric.log_tensorboard(
-                 self.summary_writer,
-                 self.total_step
-             )
-             metric.log_info(self.logger, self.total_step)
+ val_iter = iter(self.val_loader)
+ min_num_batches = torch.zeros(1, dtype=torch.long, device=self.info.device)
+ logged = False
+ while True:
+     batches = []
+     for _ in range(self.gradient_accumulation_steps):
+         batch = next(val_iter, None)
+         if batch is None:
+             break
+         elif len(batch) == 0:  # type: ignore
+             raise RuntimeError(
+                 "got empty batch, this should not happen during evaluation"
+             )
+
+         batches.append(batch)
+
+     min_num_batches[0] = len(batches)
+     dist.all_reduce(min_num_batches, op=dist.ReduceOp.MIN)
+     batches = batches[:min_num_batches.item()]
+     min_num_batches[0] = 0
+
+     if len(batches) == 0:
+         break
+
+     rank_batch_size = sum(len(batch) for batch in batches)
+     all_outputs = []
+     losses = []
+     for batch in batches:
+         inputs, labels = self._prepare_batch(batch)
+
+         with torch.autocast(
+             "cuda",
+             dtype=self.mixed_precision,
+             enabled=self.mixed_precision is not None
+         ), torch.no_grad():
+             outputs, loss_dict = self.model(**inputs)
+             loss = self.loss_fn(outputs, labels)
+
+         loss = loss + sum(loss_dict.values())
+         if self.gradient_accumulation_reduction == "mean":
+             loss = loss * len(batch) / rank_batch_size
+
+         losses.append(loss.item())
+         if not logged:
+             all_outputs.append((batch.items(), outputs))
+
+     mean_loss.add(sum(losses))
+
+     if self.info.is_main_process:
+         for items, outputs in all_outputs:
+             for metric in metrics:
+                 metric.set_values(items, outputs)
+                 metric.log_tensorboard(
+                     self.summary_writer,
+                     self.total_step
+                 )
+                 metric.log_info(self.logger, self.total_step)
+     logged = True

  end = time.perf_counter()
  mean_loss.sync()
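The loss fix itself: as reconstructed above, the old loop pushed each batch's loss into mean_loss on its own, so, if mean_loss averages the values added to it, samples from small batches were over-weighted. The new loop weights each micro-batch by its share of the accumulated group and adds the group's sum once. A quick numeric check of that identity, with made-up numbers:

# per-sample losses split into micro-batches of size 4 and 2
per_sample = [2.0, 4.0, 4.0, 1.0, 3.0, 4.0]
micro = [per_sample[:4], per_sample[4:]]
total = sum(len(b) for b in micro)  # 6

# new scheme: micro-batch means weighted by len(batch) / total,
# summed into a single value -- the true per-sample mean, 3.0
weighted = sum((sum(b) / len(b)) * (len(b) / total) for b in micro)
assert abs(weighted - sum(per_sample) / total) < 1e-9

# old scheme: averaging the two micro-batch means gives
# (2.75 + 3.5) / 2 = 3.125, not 3.0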
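The other half of the change is distributed bookkeeping. With data split unevenly across ranks, one rank can exhaust its validation loader while others still enter collective ops, and the job hangs; the new loop instead has all ranks agree on the minimum batch count before each round. A minimal sketch of that pattern, assuming torch.distributed is initialized (the helper name and signature are mine, not the repository's):

import torch
import torch.distributed as dist

def truncate_to_common_count(local_batches: list, device: torch.device) -> list:
    """Keep only as many batches as every rank can match this round."""
    count = torch.tensor([len(local_batches)], dtype=torch.long, device=device)
    # each rank contributes its own count; all ranks receive the minimum,
    # so everyone processes the same number of batches (possibly zero)
    dist.all_reduce(count, op=dist.ReduceOp.MIN)
    return local_batches[: int(count.item())]

Once the common count hits zero, every rank leaves the loop together, so the mean_loss.sync() that follows (presumably a collective) cannot deadlock on uneven splits.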
