Skip to content

Commit

Permalink
Merge pull request #15 from roboflow/feature/guided_sam_mask_generation
Browse files Browse the repository at this point in the history
feature/guided sam mask generation
  • Loading branch information
SkalskiP authored Dec 1, 2023
2 parents 971dc70 + 392f424 commit 70b9d7c
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 14 deletions.
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Multimodal-Maestro gives you more control over large multimodal models to get the
outputs you want. With more effective prompting tactics, you can get multimodal models
to do tasks you didn't know (or think!) were possible. Curious how it works? Try our
HF [space](https://huggingface.co/spaces/Roboflow/SoM)!
[HF space](https://huggingface.co/spaces/Roboflow/SoM)!

🚧 The project is still under construction, and the API is prone to change.

Expand Down Expand Up @@ -109,10 +109,9 @@ Find dog.
## 🚧 roadmap

- [ ] Rewriting the `maestro` API.
- [ ] Update [HF space](https://huggingface.co/spaces/Roboflow/SoM).
- [ ] Documentation page.
- [ ] Add GroundingDINO prompting strategy.
- [ ] Segment Anything guided marks generation.
- [ ] Non-Max Suppression marks refinement.
- [ ] CovVLM demo.
- [ ] Qwen-VL demo.

Expand Down
44 changes: 37 additions & 7 deletions maestro/markers/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import supervision as sv
from PIL import Image
from transformers import pipeline, SamModel, SamProcessor, SamImageProcessor
from typing import Union
from typing import Optional

from maestro.postprocessing.mask import masks_to_marks

Expand All @@ -15,30 +15,60 @@ class SegmentAnythingMarkGenerator:
Parameters:
device (str): The device to run the model on (e.g., 'cpu', 'cuda').
model_name (str): The name of the model to be loaded. Defaults to
'facebook/sam-vit-huge'.
'facebook/sam-vit-huge'.
"""
def __init__(self, device: str = 'cpu', model_name: str = "facebook/sam-vit-huge"):
self.model = SamModel.from_pretrained(model_name).to(device)
self.processor = SamProcessor.from_pretrained(model_name)
self.image_processor = SamImageProcessor.from_pretrained(model_name)
self.device = device
self.pipeline = pipeline(
task="mask-generation",
model=self.model,
image_processor=self.image_processor,
device=device)
device=self.device)

def generate(self, image: np.ndarray) -> sv.Detections:
def generate(
self,
image: np.ndarray,
mask: Optional[np.ndarray] = None
) -> sv.Detections:
"""
Generate image segmentation marks.
Parameters:
image (np.ndarray): The image to be marked in BGR format.
mask: (Optional[np.ndarray]): The mask to be used as a guide for
segmentation.
Returns:
sv.Detections: An object containing the segmentation masks and their
corresponding bounding box coordinates.
"""
image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
outputs = self.pipeline(image, points_per_batch=64)
masks = np.array(outputs['masks'])
return masks_to_marks(masks=masks)
if mask is None:
outputs = self.pipeline(image, points_per_batch=64)
masks = np.array(outputs['masks'])
return masks_to_marks(masks=masks)
else:
inputs = self.processor(image, return_tensors="pt").to(self.device)
image_embeddings = self.model.get_image_embeddings(inputs.pixel_values)
masks = []
for polygon in sv.mask_to_polygons(mask.astype(bool)):
indexes = np.random.choice(a=polygon.shape[0], size=5, replace=True)
input_points = polygon[indexes]
inputs = self.processor(
images=image,
input_points=[[input_points]],
return_tensors="pt"
).to(self.device)
del inputs["pixel_values"]
outputs = self.model(image_embeddings=image_embeddings, **inputs)
mask = self.processor.image_processor.post_process_masks(
masks=outputs.pred_masks.cpu().detach(),
original_sizes=inputs["original_sizes"].cpu().detach(),
reshaped_input_sizes=inputs["reshaped_input_sizes"].cpu().detach()
)[0][0][0].numpy()
masks.append(mask)
masks = np.array(masks)
return masks_to_marks(masks=masks)
5 changes: 2 additions & 3 deletions maestro/postprocessing/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,8 @@ def mask_non_max_suppression(

overlapping_masks = iou_matrix[sorted_idx[i]] > iou_threshold
overlapping_masks[sorted_idx[i]] = False
keep_mask[sorted_idx] = np.logical_and(
keep_mask[sorted_idx],
~overlapping_masks)
overlapping_indices = np.where(overlapping_masks)[0]
keep_mask[sorted_idx[overlapping_indices]] = False

return masks[keep_mask]

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "maestro"
version = "0.1.0"
version = "0.1.1rc1"
description = "Visual Prompting for Large Multimodal Models (LMMs)"
authors = ["Piotr Skalski <piotr.skalski92@gmail.com>"]
maintainers = ["Piotr Skalski <piotr.skalski92@gmail.com>"]
Expand Down

0 comments on commit 70b9d7c

Please sign in to comment.