
Commit

Merge pull request #34 from roboflow/feature/foundations_of_training_dump_metrics

new generic metrics plot system
SkalskiP authored Sep 5, 2024
2 parents 50b4876 + 7ab38eb commit e0eca6b
Showing 4 changed files with 117 additions and 53 deletions.
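In short, the commit replaces the Florence-2-specific summarise_training_metrics helper with a model-agnostic pair: save_metric_plots for per-epoch plots and MetricsTracker.as_json for raw metric dumps. A minimal sketch of how the new entry points are expected to be wired together, mirroring the calls added to florence_2/training.py below (the wrapper name dump_training_artifacts is illustrative, not part of the commit):

import os

from maestro.trainer.common.utils.metrics_tracing import MetricsTracker, save_metric_plots


def dump_training_artifacts(
    training_tracker: MetricsTracker,
    validation_tracker: MetricsTracker,
    training_dir: str,
) -> None:
    # Everything lands under {training_dir}/metrics, as in the training loop below.
    metrics_dir = os.path.join(training_dir, "metrics")
    save_metric_plots(
        training_tracker=training_tracker,
        validation_tracker=validation_tracker,
        output_dir=metrics_dir,
    )
    training_tracker.as_json(output_dir=metrics_dir, filename="training.json")
    validation_tracker.as_json(output_dir=metrics_dir, filename="validation.json")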
97 changes: 97 additions & 0 deletions maestro/trainer/common/utils/metrics_tracing.py
@@ -1,7 +1,12 @@
from __future__ import annotations

import json
import os
from collections import defaultdict
from typing import Dict, Tuple, List

import matplotlib.pyplot as plt


class MetricsTracker:

@@ -26,3 +31,95 @@ def get_metric_values(
        if with_index:
            return self._metrics[metric]
        return [value[2] for value in self._metrics[metric]]

    def as_json(
        self,
        output_dir: str = None,
        filename: str = None
    ) -> Dict[str, List[Dict[str, float]]]:
        metrics_data = {}
        for metric, values in self._metrics.items():
            metrics_data[metric] = [
                {'epoch': epoch, 'step': step, 'value': value}
                for epoch, step, value
                in values
            ]

        if output_dir and filename:
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            filepath = os.path.join(output_dir, filename)
            with open(filepath, 'w') as file:
                json.dump(metrics_data, file, indent=4)

        return metrics_data


def aggregate_by_epoch(metric_values: List[Tuple[int, int, float]]) -> Dict[int, float]:
    epoch_data = defaultdict(list)
    for epoch, step, value in metric_values:
        epoch_data[epoch].append(value)
    avg_per_epoch = {
        epoch: sum(values) / len(values)
        for epoch, values
        in epoch_data.items()
    }
    return avg_per_epoch


def save_metric_plots(
    training_tracker: MetricsTracker,
    validation_tracker: MetricsTracker,
    output_dir: str
):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    training_metrics = training_tracker.describe_metrics()
    validation_metrics = validation_tracker.describe_metrics()
    all_metrics = set(training_metrics + validation_metrics)

    for metric in all_metrics:
        plt.figure(figsize=(8, 6))

        if metric in training_metrics:
            training_values = training_tracker.get_metric_values(
                metric=metric, with_index=True)
            training_avg_values = aggregate_by_epoch(training_values)
            training_epochs = sorted(training_avg_values.keys())
            training_vals = [training_avg_values[epoch] for epoch in training_epochs]
            plt.plot(
                training_epochs,
                training_vals,
                label=f'Training {metric}',
                marker='o',
                linestyle='-',
                color='blue'
            )

        if metric in validation_metrics:
            validation_values = validation_tracker.get_metric_values(
                metric=metric, with_index=True)
            validation_avg_values = aggregate_by_epoch(validation_values)
            validation_epochs = sorted(validation_avg_values.keys())
            validation_vals = [
                validation_avg_values[epoch]
                for epoch
                in validation_epochs
            ]
            plt.plot(
                validation_epochs,
                validation_vals,
                label=f'Validation {metric}',
                marker='o',
                linestyle='--',
                color='orange'
            )

        plt.title(f'{metric.capitalize()} over Epochs')
        plt.xlabel('Epoch')
        plt.ylabel(f'{metric.capitalize()} Value')
        plt.legend()
        plt.grid(True)
        plt.savefig(f'{output_dir}/{metric}_plot.png')
        plt.close()
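For reference, a small worked example of the aggregation helper above, using hand-written (epoch, step, value) tuples in the same shape MetricsTracker stores internally (the values are made up for illustration):

from maestro.trainer.common.utils.metrics_tracing import aggregate_by_epoch

# Hypothetical per-step loss values as (epoch, step, value) tuples.
loss_values = [(1, 0, 1.0), (1, 1, 0.5), (2, 0, 0.5), (2, 1, 0.25)]

# Mean value per epoch, which save_metric_plots uses as the y-values of each plot.
print(aggregate_by_epoch(loss_values))  # {1: 0.75, 2: 0.375}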
43 changes: 0 additions & 43 deletions maestro/trainer/models/florence_2/metrics.py
@@ -3,7 +3,6 @@
import re
from typing import List, Tuple

import matplotlib.pyplot as plt
import cv2
import numpy as np
import torch
@@ -13,7 +12,6 @@

from maestro.trainer.common.data_loaders.datasets import DetectionDataset
from maestro.trainer.common.utils.file_system import save_json
from maestro.trainer.common.utils.metrics_tracing import MetricsTracker
from maestro.trainer.models.florence_2.data_loading import prepare_detection_dataset


@@ -188,44 +186,3 @@ def dump_visualised_samples(
        concatenated = cv2.hconcat([target_image, prediction_image])
        target_image_path = os.path.join(target_dir, image_name)
        cv2.imwrite(target_image_path, concatenated)


def summarise_training_metrics(
    training_metrics_tracker: MetricsTracker,
    validation_metrics_tracker: MetricsTracker,
    training_dir: str,
) -> None:
    summarise_metrics(metrics_tracker=training_metrics_tracker, training_dir=training_dir, split_name="train")
    summarise_metrics(metrics_tracker=validation_metrics_tracker, training_dir=training_dir, split_name="valid")


def summarise_metrics(
    metrics_tracker: MetricsTracker,
    training_dir: str,
    split_name: str,
) -> None:
    plots_dir_path = os.path.join(training_dir, "metrics", split_name)
    os.makedirs(plots_dir_path, exist_ok=True)
    for metric_name in metrics_tracker.describe_metrics():
        plot_path = os.path.join(plots_dir_path, f"metric_{metric_name}_plot.png")
        plt.clf()
        metric_values_with_index = metrics_tracker.get_metric_values(
            metric=metric_name,
            with_index=True,
        )
        xs = np.arange(0, len(metric_values_with_index))
        xticks_xs, xticks_labels = [], []
        previous = None
        for v, x in zip(metric_values_with_index, xs):
            if v[0] != previous:
                xticks_xs.append(x)
                xticks_labels.append(v[0])
                previous = v[0]
        ys = [e[2] for e in metric_values_with_index]
        plt.scatter(xs, ys, marker="x")
        plt.plot(xs, ys, linestyle="dashed", linewidth=0.3)
        plt.title(f"Value of {metric_name} for {split_name} set")
        plt.xticks(xticks_xs, labels=xticks_labels)
        plt.xlabel("Epochs")
        plt.savefig(plot_path, dpi=120)
        plt.clf()
26 changes: 18 additions & 8 deletions maestro/trainer/models/florence_2/training.py
@@ -12,12 +12,14 @@
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoProcessor, get_scheduler

from maestro.trainer.common.configuration.env import CUDA_DEVICE_ENV, DEFAULT_CUDA_DEVICE
from maestro.trainer.common.configuration.env import CUDA_DEVICE_ENV, \
    DEFAULT_CUDA_DEVICE
from maestro.trainer.common.utils.leaderboard import CheckpointsLeaderboard
from maestro.trainer.common.utils.metrics_tracing import MetricsTracker
from maestro.trainer.common.utils.metrics_tracing import MetricsTracker, \
    save_metric_plots
from maestro.trainer.common.utils.reproducibility import make_it_reproducible
from maestro.trainer.models.florence_2.data_loading import prepare_data_loaders
from maestro.trainer.models.florence_2.metrics import prepare_detection_training_summary, summarise_training_metrics
from maestro.trainer.models.florence_2.metrics import prepare_detection_training_summary
from maestro.trainer.models.paligemma.training import LoraInitLiteral


@@ -101,6 +103,7 @@ def train(configuration: TrainingConfiguration) -> None:
        training_metrics_tracker=training_metrics_tracker,
        validation_metrics_tracker=validation_metrics_tracker,
    )

    best_model_path = checkpoints_leaderboard.get_best_model()
    print(f"Loading best model from {best_model_path}")
    processor, model = load_model(
@@ -119,11 +122,18 @@
print(f"Saving best model: {best_model_dir}")
model.save_pretrained(best_model_dir)
processor.save_pretrained(best_model_dir)
summarise_training_metrics(
training_metrics_tracker=training_metrics_tracker,
validation_metrics_tracker=validation_metrics_tracker,
training_dir=configuration.training_dir,
save_metric_plots(
training_tracker=training_metrics_tracker,
validation_tracker=validation_metrics_tracker,
output_dir=os.path.join(configuration.training_dir, "metrics"),
)
training_metrics_tracker.as_json(
output_dir=os.path.join(configuration.training_dir, "metrics"),
filename="training.json")
validation_metrics_tracker.as_json(
output_dir=os.path.join(configuration.training_dir, "metrics"),
filename="validation.json")

for split_name in ["valid", "test"]:
prepare_detection_training_summary(
processor=processor,
@@ -304,7 +314,7 @@ def run_validation_epoch(
    val_loss = 0.0
    epoch_marker = ""
    if epoch_number is not None:
        epoch_marker = f"| Epoch {epoch_number + 1}/{configuration.training_epochs}"
        epoch_marker = f"| Epoch {epoch_number}/{configuration.training_epochs}"
    with torch.no_grad():
        for inputs, answers in tqdm(loader, desc=f"{title} {epoch_marker}"):
            input_ids = inputs["input_ids"]
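With the wiring above, each run should leave training.json and validation.json next to the generated plots under {training_dir}/metrics. A sketch of reading one of them back, assuming the record layout produced by as_json (the training_dir path and the "loss" metric name are placeholders):

import json
import os

training_dir = "training/1"  # placeholder; whatever configuration.training_dir was
with open(os.path.join(training_dir, "metrics", "training.json")) as file:
    metrics = json.load(file)

# Each metric maps to a list of {"epoch": ..., "step": ..., "value": ...} records.
for record in metrics.get("loss", [])[:3]:
    print(record["epoch"], record["step"], record["value"])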
4 changes: 2 additions & 2 deletions requirements/requirements.txt
@@ -1,10 +1,10 @@
supervision~=0.22.0
supervision~=0.24.0rc1
requests>=2.31.0,<=2.32.3
transformers~=4.44.2
torch~=2.4.0
accelerate~=0.33.0
sentencepiece~=0.2.0
peft~=0.12.0
flash-attn~=2.6.3
flash-attn~=2.6.3 # does not work on mac
einops~=0.8.0
timm~=1.0.9
