From 9f6d990904e6d6ce20e03da0f6b5d3c828238652 Mon Sep 17 00:00:00 2001 From: Onuralp SEZER Date: Tue, 17 Sep 2024 22:10:12 +0300 Subject: [PATCH] =?UTF-8?q?refactor:=20=F0=9F=A7=B9=20update=20type=20hint?= =?UTF-8?q?s=20and=20clean=20up=20docstrings=20across=20multiple=20files?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../grounding_dino_and_gpt4_vision.ipynb | 4 +- maestro/cli/main.py | 4 +- maestro/lmms/gpt4.py | 6 +- maestro/markers/sam.py | 8 +-- maestro/postprocessing/mask.py | 22 +++---- maestro/postprocessing/text.py | 11 ++-- maestro/primitives.py | 4 +- .../trainer/common/data_loaders/datasets.py | 17 ++--- maestro/trainer/common/data_loaders/jsonl.py | 7 +-- maestro/trainer/common/utils/file_system.py | 11 ++-- maestro/trainer/common/utils/leaderboard.py | 8 +-- maestro/trainer/common/utils/metrics.py | 62 ++++++++----------- .../trainer/models/florence_2/checkpoints.py | 10 +-- maestro/trainer/models/florence_2/core.py | 8 +-- .../trainer/models/florence_2/data_loading.py | 8 +-- .../trainer/models/florence_2/entrypoint.py | 10 +-- maestro/trainer/models/florence_2/metrics.py | 15 +++-- maestro/trainer/models/paligemma/training.py | 7 ++- maestro/visualizers.py | 6 +- pyproject.toml | 11 ++-- test/test_postprocess.py | 4 +- 21 files changed, 104 insertions(+), 139 deletions(-) diff --git a/cookbooks/grounding_dino_and_gpt4_vision.ipynb b/cookbooks/grounding_dino_and_gpt4_vision.ipynb index 4ef3ad8..d7150fd 100644 --- a/cookbooks/grounding_dino_and_gpt4_vision.ipynb +++ b/cookbooks/grounding_dino_and_gpt4_vision.ipynb @@ -470,8 +470,6 @@ }, "outputs": [], "source": [ - "from typing import List\n", - "\n", "import cv2\n", "import numpy as np\n", "import supervision as sv\n", @@ -486,7 +484,7 @@ " return sv.Detections(xyxy=xyxy)\n", "\n", "\n", - "def annotate(image_source: np.ndarray, detections: sv.Detections, labels: List[str] = None) -> np.ndarray:\n", + "def annotate(image_source: np.ndarray, detections: sv.Detections, labels: list[str] = None) -> np.ndarray:\n", " box_annotator = sv.BoxAnnotator()\n", " annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)\n", " annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)\n", diff --git a/maestro/cli/main.py b/maestro/cli/main.py index 2a2df2e..94559a8 100644 --- a/maestro/cli/main.py +++ b/maestro/cli/main.py @@ -8,12 +8,12 @@ @app.command(help="Display information about maestro") -def info(): +def info() -> None: typer.echo("Welcome to maestro CLI. Let's train some VLM! 🏋") @app.command(help="Display version of maestro") -def version(): +def version() -> None: typer.echo(f"Maestro version: {__version__}") diff --git a/maestro/lmms/gpt4.py b/maestro/lmms/gpt4.py index 58485ec..036e995 100644 --- a/maestro/lmms/gpt4.py +++ b/maestro/lmms/gpt4.py @@ -15,8 +15,7 @@ def encode_image_to_base64(image: np.ndarray) -> str: - """ - Encodes an image into a base64-encoded string in JPEG format. + """Encodes an image into a base64-encoded string in JPEG format. Parameters: image (np.ndarray): The image to be encoded. This should be a numpy array as @@ -56,8 +55,7 @@ def compose_payload(image: np.ndarray, prompt: str) -> dict: def prompt_image(api_key: str, image: np.ndarray, prompt: str) -> str: - """ - Sends an image and a textual prompt to the OpenAI API and returns the API's textual + """Sends an image and a textual prompt to the OpenAI API and returns the API's textual response. This function integrates an image with a user-defined prompt to generate a response diff --git a/maestro/markers/sam.py b/maestro/markers/sam.py index 7b7afdd..868d4c5 100644 --- a/maestro/markers/sam.py +++ b/maestro/markers/sam.py @@ -10,8 +10,7 @@ class SegmentAnythingMarkGenerator: - """ - A class for performing image segmentation using a specified model. + """A class for performing image segmentation using a specified model. Parameters: device (str): The device to run the model on (e.g., 'cpu', 'cuda'). @@ -19,7 +18,7 @@ class SegmentAnythingMarkGenerator: 'facebook/sam-vit-huge'. """ - def __init__(self, device: str = "cpu", model_name: str = "facebook/sam-vit-huge"): + def __init__(self, device: str = "cpu", model_name: str = "facebook/sam-vit-huge") -> None: self.model = SamModel.from_pretrained(model_name).to(device) self.processor = SamProcessor.from_pretrained(model_name) self.image_processor = SamImageProcessor.from_pretrained(model_name) @@ -29,8 +28,7 @@ def __init__(self, device: str = "cpu", model_name: str = "facebook/sam-vit-huge ) def generate(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> sv.Detections: - """ - Generate image segmentation marks. + """Generate image segmentation marks. Parameters: image (np.ndarray): The image to be marked in BGR format. diff --git a/maestro/postprocessing/mask.py b/maestro/postprocessing/mask.py index ef60a38..b19f593 100644 --- a/maestro/postprocessing/mask.py +++ b/maestro/postprocessing/mask.py @@ -6,8 +6,7 @@ class FeatureType(Enum): - """ - An enumeration to represent the types of features for mask adjustment in image + """An enumeration to represent the types of features for mask adjustment in image segmentation. """ @@ -20,8 +19,7 @@ def list(cls): def compute_mask_iou_vectorized(masks: np.ndarray) -> np.ndarray: - """ - Vectorized computation of the Intersection over Union (IoU) for all pairs of masks. + """Vectorized computation of the Intersection over Union (IoU) for all pairs of masks. Parameters: masks (np.ndarray): A 3D numpy array with shape `(N, H, W)`, where `N` is the @@ -49,8 +47,7 @@ def compute_mask_iou_vectorized(masks: np.ndarray) -> np.ndarray: def mask_non_max_suppression(masks: np.ndarray, iou_threshold: float = 0.6) -> np.ndarray: - """ - Performs Non-Max Suppression on a set of masks by prioritizing larger masks and + """Performs Non-Max Suppression on a set of masks by prioritizing larger masks and removing smaller masks that overlap significantly. When the IoU between two masks exceeds the specified threshold, the smaller mask @@ -85,8 +82,7 @@ def mask_non_max_suppression(masks: np.ndarray, iou_threshold: float = 0.6) -> n def filter_masks_by_relative_area( masks: np.ndarray, minimum_area: float = 0.01, maximum_area: float = 1.0 ) -> np.ndarray: - """ - Filters masks based on their relative area within the total area of each mask. + """Filters masks based on their relative area within the total area of each mask. Parameters: masks (np.ndarray): A 3D numpy array with shape `(N, H, W)`, where `N` is the @@ -104,7 +100,6 @@ def filter_masks_by_relative_area( ValueError: If `minimum_area` or `maximum_area` are outside the `0` to `1` range, or if `minimum_area` is greater than `maximum_area`. """ - if not (isinstance(masks, np.ndarray) and masks.ndim == 3): raise ValueError("Input must be a 3D numpy array.") @@ -122,8 +117,7 @@ def filter_masks_by_relative_area( def adjust_mask_features_by_relative_area( mask: np.ndarray, area_threshold: float, feature_type: FeatureType = FeatureType.ISLAND ) -> np.ndarray: - """ - Adjusts a mask by removing small islands or filling small holes based on a relative + """Adjusts a mask by removing small islands or filling small holes based on a relative area threshold. !!! warning @@ -162,8 +156,7 @@ def adjust_mask_features_by_relative_area( def masks_to_marks(masks: np.ndarray) -> sv.Detections: - """ - Converts a set of masks to a marks (sv.Detections) object. + """Converts a set of masks to a marks (sv.Detections) object. Parameters: masks (np.ndarray): A 3D numpy array with shape `(N, H, W)`, where `N` is the @@ -187,8 +180,7 @@ def refine_marks( minimum_mask_area: float = 0.02, maximum_mask_area: float = 1.0, ) -> sv.Detections: - """ - Refines a set of masks by removing small islands and holes, and filtering by mask + """Refines a set of masks by removing small islands and holes, and filtering by mask area. Parameters: diff --git a/maestro/postprocessing/text.py b/maestro/postprocessing/text.py index af313b4..2eb4ee9 100644 --- a/maestro/postprocessing/text.py +++ b/maestro/postprocessing/text.py @@ -1,5 +1,4 @@ import re -from typing import Dict, List import numpy as np import supervision as sv @@ -7,9 +6,8 @@ from maestro.primitives import MarkMode -def extract_marks_in_brackets(text: str, mode: MarkMode) -> List[str]: - """ - Extracts all unique marks enclosed in square brackets from a given string, based +def extract_marks_in_brackets(text: str, mode: MarkMode) -> list[str]: + """Extracts all unique marks enclosed in square brackets from a given string, based on the specified mode. Duplicates are removed and the results are sorted in descending order. @@ -38,9 +36,8 @@ def extract_marks_in_brackets(text: str, mode: MarkMode) -> List[str]: return sorted(unique_marks, reverse=False) -def extract_relevant_masks(text: str, detections: sv.Detections) -> Dict[str, np.ndarray]: - """ - Extracts relevant masks from the detections based on marks found in the given text. +def extract_relevant_masks(text: str, detections: sv.Detections) -> dict[str, np.ndarray]: + """Extracts relevant masks from the detections based on marks found in the given text. Args: text (str): The string containing marks in square brackets to be searched for. diff --git a/maestro/primitives.py b/maestro/primitives.py index 81f1b84..5005263 100644 --- a/maestro/primitives.py +++ b/maestro/primitives.py @@ -2,9 +2,7 @@ class MarkMode(Enum): - """ - An enumeration for different marking modes. - """ + """An enumeration for different marking modes.""" NUMERIC = "NUMERIC" ALPHABETIC = "ALPHABETIC" diff --git a/maestro/trainer/common/data_loaders/datasets.py b/maestro/trainer/common/data_loaders/datasets.py index 04c1c91..62bd05e 100644 --- a/maestro/trainer/common/data_loaders/datasets.py +++ b/maestro/trainer/common/data_loaders/datasets.py @@ -1,20 +1,20 @@ import json import os -from typing import Any, Dict, List, Tuple +from typing import Any from PIL import Image from transformers.pipelines.base import Dataset class JSONLDataset: - def __init__(self, jsonl_file_path: str, image_directory_path: str): + def __init__(self, jsonl_file_path: str, image_directory_path: str) -> None: self.jsonl_file_path = jsonl_file_path self.image_directory_path = image_directory_path self.entries = self._load_entries() - def _load_entries(self) -> List[Dict[str, Any]]: + def _load_entries(self) -> list[dict[str, Any]]: entries = [] - with open(self.jsonl_file_path, "r") as file: + with open(self.jsonl_file_path) as file: for line in file: data = json.loads(line) entries.append(data) @@ -23,7 +23,7 @@ def _load_entries(self) -> List[Dict[str, Any]]: def __len__(self) -> int: return len(self.entries) - def __getitem__(self, idx: int) -> Tuple[Image.Image, Dict[str, Any]]: + def __getitem__(self, idx: int) -> tuple[Image.Image, dict[str, Any]]: if idx < 0 or idx >= len(self.entries): raise IndexError("Index out of range") @@ -31,16 +31,17 @@ def __getitem__(self, idx: int) -> Tuple[Image.Image, Dict[str, Any]]: image_path = os.path.join(self.image_directory_path, entry["image"]) try: image = Image.open(image_path) - return (image, entry) except FileNotFoundError: raise FileNotFoundError(f"Image file {image_path} not found.") + else: + return (image, entry) class DetectionDataset(Dataset): - def __init__(self, jsonl_file_path: str, image_directory_path: str): + def __init__(self, jsonl_file_path: str, image_directory_path: str) -> None: self.dataset = JSONLDataset(jsonl_file_path, image_directory_path) - def __len__(self): + def __len__(self) -> int: return len(self.dataset) def __getitem__(self, idx): diff --git a/maestro/trainer/common/data_loaders/jsonl.py b/maestro/trainer/common/data_loaders/jsonl.py index 3630e11..dc817b3 100644 --- a/maestro/trainer/common/data_loaders/jsonl.py +++ b/maestro/trainer/common/data_loaders/jsonl.py @@ -1,7 +1,6 @@ from __future__ import annotations import random -from typing import List from torch.utils.data import Dataset @@ -18,14 +17,14 @@ def from_jsonl_file(cls, path: str) -> JSONLDataset: random.shuffle(file_content) return cls(jsons=file_content) - def __init__(self, jsons: List[dict]): + def __init__(self, jsons: list[dict]) -> None: self.jsons = jsons def __getitem__(self, index): return self.jsons[index] - def __len__(self): + def __len__(self) -> int: return len(self.jsons) - def shuffle(self): + def shuffle(self) -> None: random.shuffle(self.jsons) diff --git a/maestro/trainer/common/utils/file_system.py b/maestro/trainer/common/utils/file_system.py index ce61f03..c2361f9 100644 --- a/maestro/trainer/common/utils/file_system.py +++ b/maestro/trainer/common/utils/file_system.py @@ -1,10 +1,10 @@ import json import os from glob import glob -from typing import List, Union +from typing import Union -def read_jsonl(path: str) -> List[dict]: +def read_jsonl(path: str) -> list[dict]: file_lines = read_file( path=path, split_lines=True, @@ -17,8 +17,8 @@ def read_file( split_lines: bool = False, strip_white_spaces: bool = False, line_separator: str = "\n", -) -> Union[str, List[str]]: - with open(path, "r") as f: +) -> Union[str, list[str]]: + with open(path) as f: file_content = f.read() if strip_white_spaces: file_content = file_content.strip() @@ -42,8 +42,7 @@ def ensure_parent_dir_exists(path: str) -> None: def create_new_run_directory(base_output_dir: str) -> str: - """ - Creates a new numbered directory for the current training run. + """Creates a new numbered directory for the current training run. Args: base_output_dir (str): The base directory where all run directories are stored. diff --git a/maestro/trainer/common/utils/leaderboard.py b/maestro/trainer/common/utils/leaderboard.py index a09d725..6cbcddb 100644 --- a/maestro/trainer/common/utils/leaderboard.py +++ b/maestro/trainer/common/utils/leaderboard.py @@ -1,15 +1,15 @@ -from typing import Dict, Optional, Tuple +from typing import Optional class CheckpointsLeaderboard: def __init__( self, max_checkpoints: int, - ): + ) -> None: self._max_checkpoints = max(max_checkpoints, 1) - self._leaderboard: Dict[int, Tuple[str, float]] = {} + self._leaderboard: dict[int, tuple[str, float]] = {} - def register_checkpoint(self, epoch: int, path: str, loss: float) -> Tuple[bool, Optional[str]]: + def register_checkpoint(self, epoch: int, path: str, loss: float) -> tuple[bool, Optional[str]]: if len(self._leaderboard) < self._max_checkpoints: self._leaderboard[epoch] = (path, loss) return True, None diff --git a/maestro/trainer/common/utils/metrics.py b/maestro/trainer/common/utils/metrics.py index 19a257e..d855199 100644 --- a/maestro/trainer/common/utils/metrics.py +++ b/maestro/trainer/common/utils/metrics.py @@ -7,7 +7,7 @@ import os from abc import ABC, abstractmethod from collections import defaultdict -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional import matplotlib.pyplot as plt import supervision as sv @@ -16,15 +16,13 @@ class BaseMetric(ABC): - """ - Abstract base class for custom metrics. Subclasses must implement + """Abstract base class for custom metrics. Subclasses must implement the 'describe' and 'compute' methods. """ @abstractmethod - def describe(self) -> List[str]: - """ - Describe the names of the metrics that this class will compute. + def describe(self) -> list[str]: + """Describe the names of the metrics that this class will compute. Returns: List[str]: A list of metric names that will be computed. @@ -32,9 +30,8 @@ def describe(self) -> List[str]: pass @abstractmethod - def compute(self, targets: List[Any], predictions: List[Any]) -> Dict[str, float]: - """ - Compute the metric based on the targets and predictions. + def compute(self, targets: list[Any], predictions: list[Any]) -> dict[str, float]: + """Compute the metric based on the targets and predictions. Args: targets (List[Any]): The ground truth. @@ -48,22 +45,18 @@ def compute(self, targets: List[Any], predictions: List[Any]) -> Dict[str, float class MeanAveragePrecisionMetric(BaseMetric): - """ - A class used to compute the Mean Average Precision (mAP) metric. - """ + """A class used to compute the Mean Average Precision (mAP) metric.""" - def describe(self) -> List[str]: - """ - Returns a list of metric names that this class will compute. + def describe(self) -> list[str]: + """Returns a list of metric names that this class will compute. Returns: List[str]: A list of metric names. """ return ["map50:95", "map50", "map75"] - def compute(self, targets: List[sv.Detections], predictions: List[sv.Detections]) -> Dict[str, float]: - """ - Computes the mAP metrics based on the targets and predictions. + def compute(self, targets: list[sv.Detections], predictions: list[sv.Detections]) -> dict[str, float]: + """Computes the mAP metrics based on the targets and predictions. Args: targets (List[sv.Detections]): The ground truth detections. @@ -79,16 +72,16 @@ def compute(self, targets: List[sv.Detections], predictions: List[sv.Detections] class MetricsTracker: @classmethod - def init(cls, metrics: List[str]) -> MetricsTracker: + def init(cls, metrics: list[str]) -> MetricsTracker: return cls(metrics={metric: [] for metric in metrics}) - def __init__(self, metrics: Dict[str, List[Tuple[int, int, float]]]): + def __init__(self, metrics: dict[str, list[tuple[int, int, float]]]) -> None: self._metrics = metrics def register(self, metric: str, epoch: int, step: int, value: float) -> None: self._metrics[metric].append((epoch, step, value)) - def describe_metrics(self) -> List[str]: + def describe_metrics(self) -> list[str]: return list(self._metrics.keys()) def get_metric_values( @@ -102,7 +95,7 @@ def get_metric_values( def as_json( self, output_dir: Optional[str] = None, filename: Optional[str] = None - ) -> Dict[str, List[Dict[str, float]]]: + ) -> dict[str, list[dict[str, float]]]: metrics_data = {} for metric, values in self._metrics.items(): metrics_data[metric] = [{"epoch": epoch, "step": step, "value": value} for epoch, step, value in values] @@ -117,9 +110,8 @@ def as_json( return metrics_data -def aggregate_by_epoch(metric_values: List[Tuple[int, int, float]]) -> Dict[int, float]: - """ - Aggregates metric values by epoch, calculating the average for each epoch. +def aggregate_by_epoch(metric_values: list[tuple[int, int, float]]) -> dict[int, float]: + """Aggregates metric values by epoch, calculating the average for each epoch. Args: metric_values (List[Tuple[int, int, float]]): A list of tuples containing @@ -135,9 +127,8 @@ def aggregate_by_epoch(metric_values: List[Tuple[int, int, float]]) -> Dict[int, return avg_per_epoch -def save_metric_plots(training_tracker: MetricsTracker, validation_tracker: MetricsTracker, output_dir: str): - """ - Saves plots of training and validation metrics over epochs. +def save_metric_plots(training_tracker: MetricsTracker, validation_tracker: MetricsTracker, output_dir: str) -> None: + """Saves plots of training and validation metrics over epochs. Args: training_tracker (MetricsTracker): Tracker containing training metrics. @@ -190,10 +181,9 @@ def save_metric_plots(training_tracker: MetricsTracker, validation_tracker: Metr def display_results( - prompts: List[str], expected_responses: List[str], generated_texts: List[str], images: List[Image.Image] + prompts: list[str], expected_responses: list[str], generated_texts: list[str], images: list[Image.Image] ) -> None: - """ - Display the results of model inference in IPython environments. + """Display the results of model inference in IPython environments. This function attempts to display the results (prompts, expected responses, generated texts, and images) in an HTML format if running in an IPython @@ -221,10 +211,9 @@ def display_results( def create_html_output( - prompts: List[str], expected_responses: List[str], generated_texts: List[str], images: List[Image.Image] + prompts: list[str], expected_responses: list[str], generated_texts: list[str], images: list[Image.Image] ) -> str: - """ - Create an HTML string to display the results of model inference. + """Create an HTML string to display the results of model inference. This function generates an HTML string that includes styled divs for each result, containing the input image, prompt, expected response, and generated text. @@ -257,9 +246,8 @@ def create_html_output( return html_out -def render_inline(image: Image.Image, resize: Tuple[int, int] = (256, 256)) -> str: - """ - Convert an image into an inline HTML string. +def render_inline(image: Image.Image, resize: tuple[int, int] = (256, 256)) -> str: + """Convert an image into an inline HTML string. This function takes an image, resizes it, and converts it to a base64-encoded string that can be used as the source for an HTML img tag. diff --git a/maestro/trainer/models/florence_2/checkpoints.py b/maestro/trainer/models/florence_2/checkpoints.py index a5d5eb5..28f44f9 100644 --- a/maestro/trainer/models/florence_2/checkpoints.py +++ b/maestro/trainer/models/florence_2/checkpoints.py @@ -1,5 +1,5 @@ import os -from typing import Optional, Tuple +from typing import Optional import torch from transformers import AutoModelForCausalLM, AutoProcessor @@ -23,7 +23,7 @@ class CheckpointManager: best_checkpoint_dir (str): Directory for the best checkpoint. """ - def __init__(self, training_dir: str): + def __init__(self, training_dir: str) -> None: """Initializes the CheckpointManager. Args: @@ -34,7 +34,7 @@ def __init__(self, training_dir: str): self.latest_checkpoint_dir = os.path.join(training_dir, "checkpoints", "latest") self.best_checkpoint_dir = os.path.join(training_dir, "checkpoints", "best") - def save_latest(self, processor: AutoProcessor, model: AutoModelForCausalLM): + def save_latest(self, processor: AutoProcessor, model: AutoModelForCausalLM) -> None: """Saves the latest model checkpoint. Args: @@ -43,7 +43,7 @@ def save_latest(self, processor: AutoProcessor, model: AutoModelForCausalLM): """ save_model(self.latest_checkpoint_dir, processor, model) - def save_best(self, processor: AutoProcessor, model: AutoModelForCausalLM, val_loss: float): + def save_best(self, processor: AutoProcessor, model: AutoModelForCausalLM, val_loss: float) -> None: """Saves the best model checkpoint if the validation loss improves. Args: @@ -87,7 +87,7 @@ def load_model( revision: str = DEFAULT_FLORENCE2_MODEL_REVISION, device: torch.device = DEVICE, cache_dir: Optional[str] = None, -) -> Tuple[AutoProcessor, AutoModelForCausalLM]: +) -> tuple[AutoProcessor, AutoModelForCausalLM]: """Loads a Florence-2 model and its associated processor. Args: diff --git a/maestro/trainer/models/florence_2/core.py b/maestro/trainer/models/florence_2/core.py index 5b3d6e0..4b6e871 100644 --- a/maestro/trainer/models/florence_2/core.py +++ b/maestro/trainer/models/florence_2/core.py @@ -1,6 +1,6 @@ import os from dataclasses import dataclass, field, replace -from typing import List, Literal, Optional, Tuple, Union +from typing import Literal, Optional, Union import torch from peft import LoraConfig, PeftModel, get_peft_model @@ -89,7 +89,7 @@ class TrainingConfiguration: use_rslora: bool = True init_lora_weights: Union[bool, LoraInitLiteral] = "gaussian" output_dir: str = "./training/florence-2" - metrics: List[BaseMetric] = field(default_factory=list) + metrics: list[BaseMetric] = field(default_factory=list) def train(config: TrainingConfiguration) -> None: @@ -189,7 +189,7 @@ def prepare_peft_model( def run_training_loop( processor: AutoProcessor, model: PeftModel, - data_loaders: Tuple[DataLoader, Optional[DataLoader]], + data_loaders: tuple[DataLoader, Optional[DataLoader]], config: TrainingConfiguration, training_metrics_tracker: MetricsTracker, validation_metrics_tracker: MetricsTracker, @@ -234,7 +234,7 @@ def run_training_epoch( checkpoint_manager: CheckpointManager, ) -> None: model.train() - training_losses: List[float] = [] + training_losses: list[float] = [] with tqdm(total=len(train_loader), desc=f"Epoch {epoch}/{config.epochs}", unit="batch") as pbar: for step_id, (inputs, answers) in enumerate(train_loader): diff --git a/maestro/trainer/models/florence_2/data_loading.py b/maestro/trainer/models/florence_2/data_loading.py index 62fc28e..5794d6e 100644 --- a/maestro/trainer/models/florence_2/data_loading.py +++ b/maestro/trainer/models/florence_2/data_loading.py @@ -1,7 +1,7 @@ import logging import os from functools import partial -from typing import List, Optional, Tuple +from typing import Optional import torch from PIL import Image @@ -19,7 +19,7 @@ def prepare_data_loaders( num_workers: int = 0, test_batch_size: Optional[int] = None, test_loaders_workers: Optional[int] = None, -) -> Tuple[ +) -> tuple[ DataLoader, Optional[DataLoader], Optional[DataLoader], @@ -101,10 +101,10 @@ def prepare_detection_dataset( def collate_fn( - batch: Tuple[List[str], List[str], List[Image.Image]], + batch: tuple[list[str], list[str], list[Image.Image]], processor: AutoProcessor, device: torch.device, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: questions, answers, images = zip(*batch) inputs = processor(text=list(questions), images=list(images), return_tensors="pt", padding=True).to(device) return inputs, answers diff --git a/maestro/trainer/models/florence_2/entrypoint.py b/maestro/trainer/models/florence_2/entrypoint.py index 37a74ae..c11441c 100644 --- a/maestro/trainer/models/florence_2/entrypoint.py +++ b/maestro/trainer/models/florence_2/entrypoint.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Annotated, Dict, List, Literal, Optional, Type, Union +from typing import Annotated, Literal, Optional, Union import rich import torch @@ -18,12 +18,12 @@ florence_2_app = typer.Typer(help="Fine-tune and evaluate Florence 2 model") -METRIC_CLASSES: Dict[str, Type[BaseMetric]] = { +METRIC_CLASSES: dict[str, type[BaseMetric]] = { "mean_average_precision": MeanAveragePrecisionMetric, } -def parse_metrics(metrics: List[str]) -> List[BaseMetric]: +def parse_metrics(metrics: list[str]) -> list[BaseMetric]: metric_objects = [] for metric_name in metrics: metric_class = METRIC_CLASSES.get(metric_name.lower()) @@ -119,7 +119,7 @@ def train( typer.Option("--output_dir", help="Directory to save output files"), ] = "./training/florence-2", metrics: Annotated[ - List[str], + list[str], typer.Option("--metrics", help="List of metrics to track during training"), ] = [], ) -> None: @@ -191,7 +191,7 @@ def evaluate( typer.Option("--output_dir", help="Directory to save output files"), ] = "./evaluation/florence-2", metrics: Annotated[ - List[str], + list[str], typer.Option("--metrics", help="List of metrics to track during evaluation"), ] = [], ) -> None: diff --git a/maestro/trainer/models/florence_2/metrics.py b/maestro/trainer/models/florence_2/metrics.py index 0f1d51a..be805db 100644 --- a/maestro/trainer/models/florence_2/metrics.py +++ b/maestro/trainer/models/florence_2/metrics.py @@ -1,5 +1,4 @@ import re -from typing import List, Tuple import numpy as np import supervision as sv @@ -14,12 +13,12 @@ def postprocess_florence2_output_for_mean_average_precision( - expected_responses: List[str], - generated_texts: List[str], - images: List[Image.Image], - classes: List[str], + expected_responses: list[str], + generated_texts: list[str], + images: list[Image.Image], + classes: list[str], processor: AutoProcessor, -) -> Tuple[List[sv.Detections], List[sv.Detections]]: +) -> tuple[list[sv.Detections], list[sv.Detections]]: targets = [] predictions = [] @@ -52,7 +51,7 @@ def run_predictions( processor: AutoProcessor, model: AutoModelForCausalLM, device: torch.device, -) -> Tuple[List[str], List[str], List[str], List[Image.Image]]: +) -> tuple[list[str], list[str], list[str], list[Image.Image]]: prompts = [] expected_responses = [] generated_texts = [] @@ -77,7 +76,7 @@ def run_predictions( return prompts, expected_responses, generated_texts, images -def extract_unique_detection_dataset_classes(dataset: DetectionDataset) -> List[str]: +def extract_unique_detection_dataset_classes(dataset: DetectionDataset) -> list[str]: class_set = set() for i in range(len(dataset)): _, suffix, _ = dataset[i] diff --git a/maestro/trainer/models/paligemma/training.py b/maestro/trainer/models/paligemma/training.py index ebfd165..23a9af8 100644 --- a/maestro/trainer/models/paligemma/training.py +++ b/maestro/trainer/models/paligemma/training.py @@ -1,5 +1,6 @@ import os -from typing import Iterator, List, Literal, Optional, Tuple, Union +from collections.abc import Iterator +from typing import Literal, Optional, Union import torch from peft import LoraConfig, PeftModel, get_peft_model @@ -149,7 +150,7 @@ def load_model( device: torch.device = DEVICE, hf_token: Optional[str] = None, cache_dir: Optional[str] = None, -) -> Tuple[AutoProcessor, PaliGemmaForConditionalGeneration]: +) -> tuple[AutoProcessor, PaliGemmaForConditionalGeneration]: if hf_token is None: hf_token = os.getenv(HF_TOKEN_ENV) processor = AutoProcessor.from_pretrained(model_id, token=hf_token, cache_dir=cache_dir) @@ -192,7 +193,7 @@ def prepare_peft_model( def _collate_fn( - examples: List[dict], + examples: list[dict], dataset_root: str, processor: AutoProcessor, device: torch.device = DEVICE, diff --git a/maestro/visualizers.py b/maestro/visualizers.py index 7d793ab..d470ed9 100644 --- a/maestro/visualizers.py +++ b/maestro/visualizers.py @@ -3,8 +3,7 @@ class MarkVisualizer: - """ - A class for visualizing different marks including bounding boxes, masks, polygons, + """A class for visualizing different marks including bounding boxes, masks, polygons, and labels. Parameters: @@ -34,8 +33,7 @@ def visualize( with_polygon: bool = True, with_label: bool = True, ) -> np.ndarray: - """ - Visualizes annotations on an image. + """Visualizes annotations on an image. This method takes an image and an instance of sv.Detections, and overlays the specified types of marks (boxes, masks, polygons, labels) on the image. diff --git a/pyproject.toml b/pyproject.toml index 26eb9f1..a4c0417 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,10 +62,9 @@ docs = [ ] dev = [ "pytest~=8.3.2", - "black~=24.8.0", "pre-commit~=3.8.0", "mypy~=1.11.2", - "flake8~=7.1.1", + "ruff~=0.6.5", "tox~=4.18.1" ] @@ -126,8 +125,8 @@ indent-width = 4 [tool.ruff.lint] # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default. -select = ["E", "F", "I", "A", "Q", "W"] -ignore = [] +select = ["E", "F", "I", "A", "Q", "W", "N", "T", "Q","TRY","UP"] +ignore = ["T201","TRY003"] # Allow autofix for all enabled rules (when `--fix`) is provided. fixable = [ "A", @@ -192,7 +191,9 @@ convention = "google" "__init__.py" = ["E402", "F401"] "*.ipynb" = ["E501"] - +[tool.ruff.lint.pyupgrade] +# Preserve types, even if a file imports `from __future__ import annotations`. +keep-runtime-typing = true [tool.ruff.lint.mccabe] # Flag errors (`C901`) whenever the complexity level exceeds 5. diff --git a/test/test_postprocess.py b/test/test_postprocess.py index f4f6c36..bf23840 100644 --- a/test/test_postprocess.py +++ b/test/test_postprocess.py @@ -1,5 +1,3 @@ -from typing import List - import pytest from maestro.postprocessing.text import extract_marks_in_brackets @@ -23,6 +21,6 @@ ("[1] lorem ipsum [A] dolor sit amet", MarkMode.ALPHABETIC, ["A"]), ], ) -def test_extract_marks_in_brackets(text: str, mode: MarkMode, expected_result: List[str]) -> None: +def test_extract_marks_in_brackets(text: str, mode: MarkMode, expected_result: list[str]) -> None: result = extract_marks_in_brackets(text=text, mode=mode) assert result == expected_result