diff --git a/README.md b/README.md index 54ae80e..dda763e 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ Multimodal-Maestro gives you more control over large multimodal models to get the outputs you want. With more effective prompting tactics, you can get multimodal models to do tasks you didn't know (or think!) were possible. Curious how it works? Try our -HF [space](https://huggingface.co/spaces/Roboflow/SoM)! +[HF space](https://huggingface.co/spaces/Roboflow/SoM)! 🚧 The project is still under construction, and the API is prone to change. @@ -109,10 +109,9 @@ Find dog. ## 🚧 roadmap - [ ] Rewriting the `maestro` API. +- [ ] Update [HF space](https://huggingface.co/spaces/Roboflow/SoM). - [ ] Documentation page. - [ ] Add GroundingDINO prompting strategy. -- [ ] Segment Anything guided marks generation. -- [ ] Non-Max Suppression marks refinement. - [ ] CovVLM demo. - [ ] Qwen-VL demo. diff --git a/maestro/markers/sam.py b/maestro/markers/sam.py index 69183f8..17a2566 100644 --- a/maestro/markers/sam.py +++ b/maestro/markers/sam.py @@ -3,7 +3,7 @@ import supervision as sv from PIL import Image from transformers import pipeline, SamModel, SamProcessor, SamImageProcessor -from typing import Union +from typing import Optional from maestro.postprocessing.mask import masks_to_marks @@ -15,30 +15,60 @@ class SegmentAnythingMarkGenerator: Parameters: device (str): The device to run the model on (e.g., 'cpu', 'cuda'). model_name (str): The name of the model to be loaded. Defaults to - 'facebook/sam-vit-huge'. + 'facebook/sam-vit-huge'. """ def __init__(self, device: str = 'cpu', model_name: str = "facebook/sam-vit-huge"): self.model = SamModel.from_pretrained(model_name).to(device) self.processor = SamProcessor.from_pretrained(model_name) self.image_processor = SamImageProcessor.from_pretrained(model_name) + self.device = device self.pipeline = pipeline( task="mask-generation", model=self.model, image_processor=self.image_processor, - device=device) + device=self.device) - def generate(self, image: np.ndarray) -> sv.Detections: + def generate( + self, + image: np.ndarray, + mask: Optional[np.ndarray] = None + ) -> sv.Detections: """ Generate image segmentation marks. Parameters: image (np.ndarray): The image to be marked in BGR format. + mask: (Optional[np.ndarray]): The mask to be used as a guide for + segmentation. Returns: sv.Detections: An object containing the segmentation masks and their corresponding bounding box coordinates. """ image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - outputs = self.pipeline(image, points_per_batch=64) - masks = np.array(outputs['masks']) - return masks_to_marks(masks=masks) + if mask is None: + outputs = self.pipeline(image, points_per_batch=64) + masks = np.array(outputs['masks']) + return masks_to_marks(masks=masks) + else: + inputs = self.processor(image, return_tensors="pt").to(self.device) + image_embeddings = self.model.get_image_embeddings(inputs.pixel_values) + masks = [] + for polygon in sv.mask_to_polygons(mask.astype(bool)): + indexes = np.random.choice(a=polygon.shape[0], size=5, replace=True) + input_points = polygon[indexes] + inputs = self.processor( + images=image, + input_points=[[input_points]], + return_tensors="pt" + ).to(self.device) + del inputs["pixel_values"] + outputs = self.model(image_embeddings=image_embeddings, **inputs) + mask = self.processor.image_processor.post_process_masks( + masks=outputs.pred_masks.cpu().detach(), + original_sizes=inputs["original_sizes"].cpu().detach(), + reshaped_input_sizes=inputs["reshaped_input_sizes"].cpu().detach() + )[0][0][0].numpy() + masks.append(mask) + masks = np.array(masks) + return masks_to_marks(masks=masks) diff --git a/maestro/postprocessing/mask.py b/maestro/postprocessing/mask.py index 7f09b40..b028495 100644 --- a/maestro/postprocessing/mask.py +++ b/maestro/postprocessing/mask.py @@ -78,9 +78,8 @@ def mask_non_max_suppression( overlapping_masks = iou_matrix[sorted_idx[i]] > iou_threshold overlapping_masks[sorted_idx[i]] = False - keep_mask[sorted_idx] = np.logical_and( - keep_mask[sorted_idx], - ~overlapping_masks) + overlapping_indices = np.where(overlapping_masks)[0] + keep_mask[sorted_idx[overlapping_indices]] = False return masks[keep_mask] diff --git a/pyproject.toml b/pyproject.toml index 5104d3c..26a2983 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "maestro" -version = "0.1.0" +version = "0.1.1rc1" description = "Visual Prompting for Large Multimodal Models (LMMs)" authors = ["Piotr Skalski "] maintainers = ["Piotr Skalski "]