Skip to content

Commit

Permalink
refactor: šŸ§¹ update type hints and clean up docstrings across multipleā€¦
Browse files Browse the repository at this point in the history
ā€¦ files
  • Loading branch information
onuralpszr committed Sep 17, 2024
1 parent c84221c commit 9f6d990
Show file tree
Hide file tree
Showing 21 changed files with 104 additions and 139 deletions.
4 changes: 1 addition & 3 deletions cookbooks/grounding_dino_and_gpt4_vision.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,6 @@
},
"outputs": [],
"source": [
"from typing import List\n",
"\n",
"import cv2\n",
"import numpy as np\n",
"import supervision as sv\n",
Expand All @@ -486,7 +484,7 @@
" return sv.Detections(xyxy=xyxy)\n",
"\n",
"\n",
"def annotate(image_source: np.ndarray, detections: sv.Detections, labels: List[str] = None) -> np.ndarray:\n",
"def annotate(image_source: np.ndarray, detections: sv.Detections, labels: list[str] = None) -> np.ndarray:\n",
" box_annotator = sv.BoxAnnotator()\n",
" annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)\n",
" annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)\n",
Expand Down
4 changes: 2 additions & 2 deletions maestro/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@


@app.command(help="Display information about maestro")
def info():
def info() -> None:
typer.echo("Welcome to maestro CLI. Let's train some VLM! šŸ‹")


@app.command(help="Display version of maestro")
def version():
def version() -> None:
typer.echo(f"Maestro version: {__version__}")


Expand Down
6 changes: 2 additions & 4 deletions maestro/lmms/gpt4.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@


def encode_image_to_base64(image: np.ndarray) -> str:
"""
Encodes an image into a base64-encoded string in JPEG format.
"""Encodes an image into a base64-encoded string in JPEG format.
Parameters:
image (np.ndarray): The image to be encoded. This should be a numpy array as
Expand Down Expand Up @@ -56,8 +55,7 @@ def compose_payload(image: np.ndarray, prompt: str) -> dict:


def prompt_image(api_key: str, image: np.ndarray, prompt: str) -> str:
"""
Sends an image and a textual prompt to the OpenAI API and returns the API's textual
"""Sends an image and a textual prompt to the OpenAI API and returns the API's textual
response.
This function integrates an image with a user-defined prompt to generate a response
Expand Down
8 changes: 3 additions & 5 deletions maestro/markers/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,15 @@


class SegmentAnythingMarkGenerator:
"""
A class for performing image segmentation using a specified model.
"""A class for performing image segmentation using a specified model.
Parameters:
device (str): The device to run the model on (e.g., 'cpu', 'cuda').
model_name (str): The name of the model to be loaded. Defaults to
'facebook/sam-vit-huge'.
"""

def __init__(self, device: str = "cpu", model_name: str = "facebook/sam-vit-huge"):
def __init__(self, device: str = "cpu", model_name: str = "facebook/sam-vit-huge") -> None:
self.model = SamModel.from_pretrained(model_name).to(device)
self.processor = SamProcessor.from_pretrained(model_name)
self.image_processor = SamImageProcessor.from_pretrained(model_name)
Expand All @@ -29,8 +28,7 @@ def __init__(self, device: str = "cpu", model_name: str = "facebook/sam-vit-huge
)

def generate(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> sv.Detections:
"""
Generate image segmentation marks.
"""Generate image segmentation marks.
Parameters:
image (np.ndarray): The image to be marked in BGR format.
Expand Down
22 changes: 7 additions & 15 deletions maestro/postprocessing/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@


class FeatureType(Enum):
"""
An enumeration to represent the types of features for mask adjustment in image
"""An enumeration to represent the types of features for mask adjustment in image
segmentation.
"""

Expand All @@ -20,8 +19,7 @@ def list(cls):


def compute_mask_iou_vectorized(masks: np.ndarray) -> np.ndarray:
"""
Vectorized computation of the Intersection over Union (IoU) for all pairs of masks.
"""Vectorized computation of the Intersection over Union (IoU) for all pairs of masks.
Parameters:
masks (np.ndarray): A 3D numpy array with shape `(N, H, W)`, where `N` is the
Expand Down Expand Up @@ -49,8 +47,7 @@ def compute_mask_iou_vectorized(masks: np.ndarray) -> np.ndarray:


def mask_non_max_suppression(masks: np.ndarray, iou_threshold: float = 0.6) -> np.ndarray:
"""
Performs Non-Max Suppression on a set of masks by prioritizing larger masks and
"""Performs Non-Max Suppression on a set of masks by prioritizing larger masks and
removing smaller masks that overlap significantly.
When the IoU between two masks exceeds the specified threshold, the smaller mask
Expand Down Expand Up @@ -85,8 +82,7 @@ def mask_non_max_suppression(masks: np.ndarray, iou_threshold: float = 0.6) -> n
def filter_masks_by_relative_area(
masks: np.ndarray, minimum_area: float = 0.01, maximum_area: float = 1.0
) -> np.ndarray:
"""
Filters masks based on their relative area within the total area of each mask.
"""Filters masks based on their relative area within the total area of each mask.
Parameters:
masks (np.ndarray): A 3D numpy array with shape `(N, H, W)`, where `N` is the
Expand All @@ -104,7 +100,6 @@ def filter_masks_by_relative_area(
ValueError: If `minimum_area` or `maximum_area` are outside the `0` to `1`
range, or if `minimum_area` is greater than `maximum_area`.
"""

if not (isinstance(masks, np.ndarray) and masks.ndim == 3):
raise ValueError("Input must be a 3D numpy array.")

Expand All @@ -122,8 +117,7 @@ def filter_masks_by_relative_area(
def adjust_mask_features_by_relative_area(
mask: np.ndarray, area_threshold: float, feature_type: FeatureType = FeatureType.ISLAND
) -> np.ndarray:
"""
Adjusts a mask by removing small islands or filling small holes based on a relative
"""Adjusts a mask by removing small islands or filling small holes based on a relative
area threshold.
!!! warning
Expand Down Expand Up @@ -162,8 +156,7 @@ def adjust_mask_features_by_relative_area(


def masks_to_marks(masks: np.ndarray) -> sv.Detections:
"""
Converts a set of masks to a marks (sv.Detections) object.
"""Converts a set of masks to a marks (sv.Detections) object.
Parameters:
masks (np.ndarray): A 3D numpy array with shape `(N, H, W)`, where `N` is the
Expand All @@ -187,8 +180,7 @@ def refine_marks(
minimum_mask_area: float = 0.02,
maximum_mask_area: float = 1.0,
) -> sv.Detections:
"""
Refines a set of masks by removing small islands and holes, and filtering by mask
"""Refines a set of masks by removing small islands and holes, and filtering by mask
area.
Parameters:
Expand Down
11 changes: 4 additions & 7 deletions maestro/postprocessing/text.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import re
from typing import Dict, List

import numpy as np
import supervision as sv

from maestro.primitives import MarkMode


def extract_marks_in_brackets(text: str, mode: MarkMode) -> List[str]:
"""
Extracts all unique marks enclosed in square brackets from a given string, based
def extract_marks_in_brackets(text: str, mode: MarkMode) -> list[str]:
"""Extracts all unique marks enclosed in square brackets from a given string, based
on the specified mode. Duplicates are removed and the results are sorted in
descending order.
Expand Down Expand Up @@ -38,9 +36,8 @@ def extract_marks_in_brackets(text: str, mode: MarkMode) -> List[str]:
return sorted(unique_marks, reverse=False)


def extract_relevant_masks(text: str, detections: sv.Detections) -> Dict[str, np.ndarray]:
"""
Extracts relevant masks from the detections based on marks found in the given text.
def extract_relevant_masks(text: str, detections: sv.Detections) -> dict[str, np.ndarray]:
"""Extracts relevant masks from the detections based on marks found in the given text.
Args:
text (str): The string containing marks in square brackets to be searched for.
Expand Down
4 changes: 1 addition & 3 deletions maestro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@


class MarkMode(Enum):
"""
An enumeration for different marking modes.
"""
"""An enumeration for different marking modes."""

NUMERIC = "NUMERIC"
ALPHABETIC = "ALPHABETIC"
Expand Down
17 changes: 9 additions & 8 deletions maestro/trainer/common/data_loaders/datasets.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import json
import os
from typing import Any, Dict, List, Tuple
from typing import Any

from PIL import Image
from transformers.pipelines.base import Dataset


class JSONLDataset:
def __init__(self, jsonl_file_path: str, image_directory_path: str):
def __init__(self, jsonl_file_path: str, image_directory_path: str) -> None:
self.jsonl_file_path = jsonl_file_path
self.image_directory_path = image_directory_path
self.entries = self._load_entries()

def _load_entries(self) -> List[Dict[str, Any]]:
def _load_entries(self) -> list[dict[str, Any]]:
entries = []
with open(self.jsonl_file_path, "r") as file:
with open(self.jsonl_file_path) as file:
for line in file:
data = json.loads(line)
entries.append(data)
Expand All @@ -23,24 +23,25 @@ def _load_entries(self) -> List[Dict[str, Any]]:
def __len__(self) -> int:
return len(self.entries)

def __getitem__(self, idx: int) -> Tuple[Image.Image, Dict[str, Any]]:
def __getitem__(self, idx: int) -> tuple[Image.Image, dict[str, Any]]:
if idx < 0 or idx >= len(self.entries):
raise IndexError("Index out of range")

entry = self.entries[idx]
image_path = os.path.join(self.image_directory_path, entry["image"])
try:
image = Image.open(image_path)
return (image, entry)
except FileNotFoundError:
raise FileNotFoundError(f"Image file {image_path} not found.")
else:
return (image, entry)


class DetectionDataset(Dataset):
def __init__(self, jsonl_file_path: str, image_directory_path: str):
def __init__(self, jsonl_file_path: str, image_directory_path: str) -> None:
self.dataset = JSONLDataset(jsonl_file_path, image_directory_path)

def __len__(self):
def __len__(self) -> int:
return len(self.dataset)

def __getitem__(self, idx):
Expand Down
7 changes: 3 additions & 4 deletions maestro/trainer/common/data_loaders/jsonl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import random
from typing import List

from torch.utils.data import Dataset

Expand All @@ -18,14 +17,14 @@ def from_jsonl_file(cls, path: str) -> JSONLDataset:
random.shuffle(file_content)
return cls(jsons=file_content)

def __init__(self, jsons: List[dict]):
def __init__(self, jsons: list[dict]) -> None:
self.jsons = jsons

def __getitem__(self, index):
return self.jsons[index]

def __len__(self):
def __len__(self) -> int:
return len(self.jsons)

def shuffle(self):
def shuffle(self) -> None:
random.shuffle(self.jsons)
11 changes: 5 additions & 6 deletions maestro/trainer/common/utils/file_system.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import json
import os
from glob import glob
from typing import List, Union
from typing import Union


def read_jsonl(path: str) -> List[dict]:
def read_jsonl(path: str) -> list[dict]:
file_lines = read_file(
path=path,
split_lines=True,
Expand All @@ -17,8 +17,8 @@ def read_file(
split_lines: bool = False,
strip_white_spaces: bool = False,
line_separator: str = "\n",
) -> Union[str, List[str]]:
with open(path, "r") as f:
) -> Union[str, list[str]]:
with open(path) as f:
file_content = f.read()
if strip_white_spaces:
file_content = file_content.strip()
Expand All @@ -42,8 +42,7 @@ def ensure_parent_dir_exists(path: str) -> None:


def create_new_run_directory(base_output_dir: str) -> str:
"""
Creates a new numbered directory for the current training run.
"""Creates a new numbered directory for the current training run.
Args:
base_output_dir (str): The base directory where all run directories are stored.
Expand Down
8 changes: 4 additions & 4 deletions maestro/trainer/common/utils/leaderboard.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import Dict, Optional, Tuple
from typing import Optional


class CheckpointsLeaderboard:
def __init__(
self,
max_checkpoints: int,
):
) -> None:
self._max_checkpoints = max(max_checkpoints, 1)
self._leaderboard: Dict[int, Tuple[str, float]] = {}
self._leaderboard: dict[int, tuple[str, float]] = {}

def register_checkpoint(self, epoch: int, path: str, loss: float) -> Tuple[bool, Optional[str]]:
def register_checkpoint(self, epoch: int, path: str, loss: float) -> tuple[bool, Optional[str]]:
if len(self._leaderboard) < self._max_checkpoints:
self._leaderboard[epoch] = (path, loss)
return True, None
Expand Down
Loading

0 comments on commit 9f6d990

Please sign in to comment.