Skip to content

Commit

Permalink
make metrics a list
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiscode committed Jul 9, 2024
1 parent 8310b66 commit df3e3fb
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions python/text_utils/api/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,8 +1044,8 @@ def _train_one_epoch(self):
1, dtype=torch.long, device=self.info.device)

metrics = []
for name, cfg in self.cfg["train"].get("metrics", {}).items():
metric = self._metric_from_config(name, cfg, "train")
for metric_cfg in self.cfg["train"].get("metrics", []):
metric = self._metric_from_config(metric_cfg, "train")
metrics.append(metric)

start_items = self.epoch_items
Expand Down Expand Up @@ -1333,10 +1333,9 @@ def _evaluate_and_checkpoint(self):
self.loss_fn = self.loss_fn.eval()

metrics = []
for name, cfg in self.cfg["val"].get("metrics", {}).items():
for metric_cfg in self.cfg["val"].get("metrics", []):
metric = self._metric_from_config(
name,
cfg,
metric_cfg,
"val"
)
metrics.append(metric)
Expand Down

0 comments on commit df3e3fb

Please sign in to comment.