diff --git a/maestro/trainer/common/utils/metrics_tracing.py b/maestro/trainer/common/utils/metrics_tracing.py index 6c5b6ae..d99bbbe 100644 --- a/maestro/trainer/common/utils/metrics_tracing.py +++ b/maestro/trainer/common/utils/metrics_tracing.py @@ -1,7 +1,12 @@ from __future__ import annotations +import json +import os +from collections import defaultdict from typing import Dict, Tuple, List +import matplotlib.pyplot as plt + class MetricsTracker: @@ -26,3 +31,95 @@ def get_metric_values( if with_index: return self._metrics[metric] return [value[2] for value in self._metrics[metric]] + + def as_json( + self, + output_dir: str = None, + filename: str = None + ) -> Dict[str, List[Dict[str, float]]]: + metrics_data = {} + for metric, values in self._metrics.items(): + metrics_data[metric] = [ + {'epoch': epoch, 'step': step, 'value': value} + for epoch, step, value + in values + ] + + if output_dir and filename: + if not os.path.exists(output_dir): + os.makedirs(output_dir) + filepath = os.path.join(output_dir, filename) + with open(filepath, 'w') as file: + json.dump(metrics_data, file, indent=4) + + return metrics_data + + +def aggregate_by_epoch(metric_values: List[Tuple[int, int, float]]) -> Dict[int, float]: + epoch_data = defaultdict(list) + for epoch, step, value in metric_values: + epoch_data[epoch].append(value) + avg_per_epoch = { + epoch: sum(values) / len(values) + for epoch, values + in epoch_data.items() + } + return avg_per_epoch + + +def save_metric_plots( + training_tracker: MetricsTracker, + validation_tracker: MetricsTracker, + output_dir: str +): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + training_metrics = training_tracker.describe_metrics() + validation_metrics = validation_tracker.describe_metrics() + all_metrics = set(training_metrics + validation_metrics) + + for metric in all_metrics: + plt.figure(figsize=(8, 6)) + + if metric in training_metrics: + training_values = training_tracker.get_metric_values( + metric=metric, with_index=True) + training_avg_values = aggregate_by_epoch(training_values) + training_epochs = sorted(training_avg_values.keys()) + training_vals = [training_avg_values[epoch] for epoch in training_epochs] + plt.plot( + training_epochs, + training_vals, + label=f'Training {metric}', + marker='o', + linestyle='-', + color='blue' + ) + + if metric in validation_metrics: + validation_values = validation_tracker.get_metric_values( + metric=metric, with_index=True) + validation_avg_values = aggregate_by_epoch(validation_values) + validation_epochs = sorted(validation_avg_values.keys()) + validation_vals = [ + validation_avg_values[epoch] + for epoch + in validation_epochs + ] + plt.plot( + validation_epochs, + validation_vals, + label=f'Validation {metric}', + marker='o', + linestyle='--', + color='orange' + ) + + plt.title(f'{metric.capitalize()} over Epochs') + plt.xlabel('Epoch') + plt.ylabel(f'{metric.capitalize()} Value') + plt.legend() + plt.grid(True) + plt.savefig(f'{output_dir}/{metric}_plot.png') + plt.close() diff --git a/maestro/trainer/models/florence_2/metrics.py b/maestro/trainer/models/florence_2/metrics.py index 5746665..3ed7aab 100644 --- a/maestro/trainer/models/florence_2/metrics.py +++ b/maestro/trainer/models/florence_2/metrics.py @@ -3,7 +3,6 @@ import re from typing import List, Tuple -import matplotlib.pyplot as plt import cv2 import numpy as np import torch @@ -13,7 +12,6 @@ from maestro.trainer.common.data_loaders.datasets import DetectionDataset from maestro.trainer.common.utils.file_system import save_json -from maestro.trainer.common.utils.metrics_tracing import MetricsTracker from maestro.trainer.models.florence_2.data_loading import prepare_detection_dataset @@ -188,44 +186,3 @@ def dump_visualised_samples( concatenated = cv2.hconcat([target_image, prediction_image]) target_image_path = os.path.join(target_dir, image_name) cv2.imwrite(target_image_path, concatenated) - - -def summarise_training_metrics( - training_metrics_tracker: MetricsTracker, - validation_metrics_tracker: MetricsTracker, - training_dir: str, -) -> None: - summarise_metrics(metrics_tracker=training_metrics_tracker, training_dir=training_dir, split_name="train") - summarise_metrics(metrics_tracker=validation_metrics_tracker, training_dir=training_dir, split_name="valid") - - -def summarise_metrics( - metrics_tracker: MetricsTracker, - training_dir: str, - split_name: str, -) -> None: - plots_dir_path = os.path.join(training_dir, "metrics", split_name) - os.makedirs(plots_dir_path, exist_ok=True) - for metric_name in metrics_tracker.describe_metrics(): - plot_path = os.path.join(plots_dir_path, f"metric_{metric_name}_plot.png") - plt.clf() - metric_values_with_index = metrics_tracker.get_metric_values( - metric=metric_name, - with_index=True, - ) - xs = np.arange(0, len(metric_values_with_index)) - xticks_xs, xticks_labels = [], [] - previous = None - for v, x in zip(metric_values_with_index, xs): - if v[0] != previous: - xticks_xs.append(x) - xticks_labels.append(v[0]) - previous = v[0] - ys = [e[2] for e in metric_values_with_index] - plt.scatter(xs, ys, marker="x") - plt.plot(xs, ys, linestyle="dashed", linewidth=0.3) - plt.title(f"Value of {metric_name} for {split_name} set") - plt.xticks(xticks_xs, labels=xticks_labels) - plt.xlabel("Epochs") - plt.savefig(plot_path, dpi=120) - plt.clf() diff --git a/maestro/trainer/models/florence_2/training.py b/maestro/trainer/models/florence_2/training.py index 4a0cd5e..aaa1687 100644 --- a/maestro/trainer/models/florence_2/training.py +++ b/maestro/trainer/models/florence_2/training.py @@ -12,12 +12,14 @@ from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoProcessor, get_scheduler -from maestro.trainer.common.configuration.env import CUDA_DEVICE_ENV, DEFAULT_CUDA_DEVICE +from maestro.trainer.common.configuration.env import CUDA_DEVICE_ENV, \ + DEFAULT_CUDA_DEVICE from maestro.trainer.common.utils.leaderboard import CheckpointsLeaderboard -from maestro.trainer.common.utils.metrics_tracing import MetricsTracker +from maestro.trainer.common.utils.metrics_tracing import MetricsTracker, \ + save_metric_plots from maestro.trainer.common.utils.reproducibility import make_it_reproducible from maestro.trainer.models.florence_2.data_loading import prepare_data_loaders -from maestro.trainer.models.florence_2.metrics import prepare_detection_training_summary, summarise_training_metrics +from maestro.trainer.models.florence_2.metrics import prepare_detection_training_summary from maestro.trainer.models.paligemma.training import LoraInitLiteral @@ -101,6 +103,7 @@ def train(configuration: TrainingConfiguration) -> None: training_metrics_tracker=training_metrics_tracker, validation_metrics_tracker=validation_metrics_tracker, ) + best_model_path = checkpoints_leaderboard.get_best_model() print(f"Loading best model from {best_model_path}") processor, model = load_model( @@ -119,11 +122,18 @@ def train(configuration: TrainingConfiguration) -> None: print(f"Saving best model: {best_model_dir}") model.save_pretrained(best_model_dir) processor.save_pretrained(best_model_dir) - summarise_training_metrics( - training_metrics_tracker=training_metrics_tracker, - validation_metrics_tracker=validation_metrics_tracker, - training_dir=configuration.training_dir, + save_metric_plots( + training_tracker=training_metrics_tracker, + validation_tracker=validation_metrics_tracker, + output_dir=os.path.join(configuration.training_dir, "metrics"), ) + training_metrics_tracker.as_json( + output_dir=os.path.join(configuration.training_dir, "metrics"), + filename="training.json") + validation_metrics_tracker.as_json( + output_dir=os.path.join(configuration.training_dir, "metrics"), + filename="validation.json") + for split_name in ["valid", "test"]: prepare_detection_training_summary( processor=processor, @@ -304,7 +314,7 @@ def run_validation_epoch( val_loss = 0.0 epoch_marker = "" if epoch_number is not None: - epoch_marker = f"| Epoch {epoch_number + 1}/{configuration.training_epochs}" + epoch_marker = f"| Epoch {epoch_number}/{configuration.training_epochs}" with torch.no_grad(): for inputs, answers in tqdm(loader, desc=f"{title} {epoch_marker}"): input_ids = inputs["input_ids"] diff --git a/requirements/requirements.txt b/requirements/requirements.txt index c10deb3..a181a1a 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,10 +1,10 @@ -supervision~=0.22.0 +supervision~=0.24.0rc1 requests>=2.31.0,<=2.32.3 transformers~=4.44.2 torch~=2.4.0 accelerate~=0.33.0 sentencepiece~=0.2.0 peft~=0.12.0 -flash-attn~=2.6.3 +flash-attn~=2.6.3 # does not work on mac einops~=0.8.0 timm~=1.0.9 \ No newline at end of file