Merge pull request #755 from roboflow/owl-v2-improvements
OWL-ViT model enhancements
probicheaux authored Oct 31, 2024
2 parents 2cf0efa + 877554d commit ada5c76
Showing 3 changed files with 181 additions and 69 deletions.
4 changes: 4 additions & 0 deletions inference/core/entities/requests/owlv2.py
@@ -16,6 +16,10 @@ class TrainBox(BaseModel):
w: int = Field(description="Width in pixels of train box")
h: int = Field(description="Height in pixels of train box")
cls: str = Field(description="Class name of object this box encloses")
negative: bool = Field(
default=False,
description="Whether this object is a positive or negative example for this class",
)


class TrainingImage(BaseModel):
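As a quick illustration of the new field, a hypothetical pair of training boxes for the same class; the positive box mirrors the updated test below, and the negative box's coordinates are invented:

# Hypothetical request payloads: "negative": True marks a counter-example
# intended to suppress look-alike detections for this class.
positive_box = {"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post", "negative": False}
negative_box = {"x": 100, "y": 120, "w": 38, "h": 210, "cls": "post", "negative": True}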
244 changes: 176 additions & 68 deletions inference/models/owlv2/owlv2.py
@@ -1,7 +1,7 @@
import hashlib
import os
from collections import defaultdict
-from typing import Dict, List, NewType, Tuple
+from typing import Any, Dict, List, Literal, NewType, Tuple

import numpy as np
import torch
@@ -15,15 +15,19 @@
ObjectDetectionInferenceResponse,
ObjectDetectionPrediction,
)
-from inference.core.env import DEVICE
+from inference.core.env import DEVICE, MAX_DETECTIONS
from inference.core.models.roboflow import (
DEFAULT_COLOR_PALETTE,
RoboflowCoreModel,
draw_detection_predictions,
)
from inference.core.utils.image_utils import load_image_rgb

# TYPES
Hash = NewType("Hash", str)
PosNegKey = Literal["positive", "negative"]
PosNegDictType = Dict[PosNegKey, torch.Tensor]
QuerySpecType = Dict[Hash, List[List[int]]]
if DEVICE is None:
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

@@ -101,6 +105,93 @@ def preprocess_image(
return padded_image_tensor


def filter_tensors_by_objectness(
objectness: torch.Tensor,
boxes: torch.Tensor,
image_class_embeds: torch.Tensor,
logit_shift: torch.Tensor,
logit_scale: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
objectness = objectness.squeeze(0)
objectness, objectness_indices = torch.topk(objectness, MAX_DETECTIONS, dim=0)
boxes = boxes.squeeze(0)
image_class_embeds = image_class_embeds.squeeze(0)
logit_shift = logit_shift.squeeze(0).squeeze(1)
logit_scale = logit_scale.squeeze(0).squeeze(1)
boxes = boxes[objectness_indices]
image_class_embeds = image_class_embeds[objectness_indices]
logit_shift = logit_shift[objectness_indices]
logit_scale = logit_scale[objectness_indices]
return objectness, boxes, image_class_embeds, logit_shift, logit_scale
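For intuition, a self-contained sketch of the same top-k-by-objectness pattern; the proposal count and MAX_DETECTIONS value are assumptions for illustration, not values taken from this diff:

import torch

MAX_DETECTIONS = 100               # assumed cap; the real value comes from inference.core.env
objectness = torch.rand(3000)      # per-proposal objectness scores
boxes = torch.rand(3000, 4)        # companion (cx, cy, w, h) proposals
objectness, keep = torch.topk(objectness, MAX_DETECTIONS, dim=0)
boxes = boxes[keep]                # every companion tensor is indexed with the same top-k order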


def get_class_preds_from_embeds(
pos_neg_embedding_dict: PosNegDictType,
image_class_embeds: torch.Tensor,
    confidence: float,
    image_boxes: torch.Tensor,
    class_map: Dict[Tuple[str, str], int],
    class_name: str,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
predicted_boxes_per_class = []
predicted_class_indices_per_class = []
predicted_scores_per_class = []
positive_arr_per_class = []
for positive, embedding in pos_neg_embedding_dict.items():
if embedding is None:
continue
pred_logits = torch.einsum("sd,nd->ns", image_class_embeds, embedding)
prediction_scores = pred_logits.max(dim=0)[0]
prediction_scores = (prediction_scores + 1) / 2
score_mask = prediction_scores > confidence
predicted_boxes_per_class.append(image_boxes[score_mask])
scores = prediction_scores[score_mask]
predicted_scores_per_class.append(scores)
class_ind = class_map[(class_name, positive)]
predicted_class_indices_per_class.append(class_ind * torch.ones_like(scores))
positive_arr_per_class.append(
int(positive == "positive") * torch.ones_like(scores)
)

if not predicted_boxes_per_class:
return (
np.empty((0, 4)),
np.empty((0,)),
np.empty((0,)),
)

# concat tensors
pred_boxes = torch.cat(predicted_boxes_per_class, dim=0).float()
pred_classes = torch.cat(predicted_class_indices_per_class, dim=0).float()
pred_scores = torch.cat(predicted_scores_per_class, dim=0).float()
positive = torch.cat(positive_arr_per_class, dim=0).float()
# nms
survival_indices = torchvision.ops.nms(to_corners(pred_boxes), pred_scores, 0.3)
# put on numpy and filter to post-nms
pred_boxes = pred_boxes[survival_indices, :].detach().cpu().numpy()
pred_classes = pred_classes[survival_indices].detach().cpu().numpy()
pred_scores = pred_scores[survival_indices].detach().cpu().numpy()
positive = positive[survival_indices].detach().cpu().numpy()
is_positive = positive == 1
# return only positive elements of tensor
return pred_boxes[is_positive], pred_classes[is_positive], pred_scores[is_positive]
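Two details above are worth spelling out: the einsum scores every proposal against each saved prototype for the class and keeps the best match per proposal, with (prediction_scores + 1) / 2 mapping a cosine-style similarity from [-1, 1] into [0, 1]; and NMS runs over positive and negative predictions together, so a proposal that better matches a negative prototype survives NMS as the negative class and is then dropped by the final is_positive filter. A sketch of the scoring step, assuming unit-norm embeddings (shapes invented):

import torch
import torch.nn.functional as F

image_class_embeds = F.normalize(torch.rand(5, 512), dim=-1)  # 5 proposals (d = 512 assumed)
prototypes = F.normalize(torch.rand(3, 512), dim=-1)          # 3 prototypes for one class
logits = torch.einsum("sd,nd->ns", image_class_embeds, prototypes)  # shape (3, 5)
scores = (logits.max(dim=0)[0] + 1) / 2  # best prototype per proposal, mapped into [0, 1]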


def make_class_map(
query_embeddings: Dict[str, PosNegDictType]
) -> Tuple[Dict[Tuple[str, str], int], List[str]]:
class_names = sorted(list(query_embeddings.keys()))
class_map_positive = {
(class_name, "positive"): i for i, class_name in enumerate(class_names)
}
class_map_negative = {
(class_name, "negative"): i + len(class_names)
for i, class_name in enumerate(class_names)
}
class_map = {**class_map_positive, **class_map_negative}
return class_map, class_names
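A quick check of the resulting index layout (assuming this module is importable): positive classes take ids [0, n) in sorted order and their negative counterparts take [n, 2n):

class_map, class_names = make_class_map({"cat": {}, "dog": {}})
assert class_names == ["cat", "dog"]
assert class_map[("cat", "positive")] == 0
assert class_map[("dog", "negative")] == 3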


class OwlV2(RoboflowCoreModel):
task_type = "object-detection"
box_format = "xywh"
@@ -200,22 +291,28 @@ def embed_image(self, image: np.ndarray) -> Hash:
)
objectness = objectness.sigmoid()

+        objectness, boxes, image_class_embeds, logit_shift, logit_scale = (
+            filter_tensors_by_objectness(
+                objectness, boxes, image_class_embeds, logit_shift, logit_scale
+            )
+        )

        self.image_embed_cache[image_hash] = (
-            objectness.squeeze(0),
-            boxes.squeeze(0),
-            image_class_embeds.squeeze(0),
-            logit_shift.squeeze(0).squeeze(1),
-            logit_scale.squeeze(0).squeeze(1),
+            objectness,
+            boxes,
+            image_class_embeds,
+            logit_shift,
+            logit_scale,
        )

return image_hash

-    def get_query_embedding(self, query_spec: Dict[Hash, List[List[int]]]):
+    def get_query_embedding(self, query_spec: QuerySpecType) -> torch.Tensor:
        # NOTE: for now we're handling each image separately
query_embeds = []
for image_hash, query_boxes in query_spec.items():
try:
-                objectness, image_boxes, image_class_embeds, _, _ = (
+                _objectness, image_boxes, image_class_embeds, _, _ = (
self.image_embed_cache[image_hash]
)
except KeyError as error:
Expand All @@ -224,54 +321,45 @@ def get_query_embedding(self, query_spec: Dict[Hash, List[List[int]]]):
query_boxes_tensor = torch.tensor(
query_boxes, dtype=image_boxes.dtype, device=image_boxes.device
)
-            iou, union = box_iou(
+            iou, _ = box_iou(
to_corners(image_boxes), to_corners(query_boxes_tensor)
) # 3000, k
-            iou_mask = iou > 0.4
-            valid_objectness = torch.where(
-                iou_mask, objectness.unsqueeze(-1), -1
-            )  # 3000, k
-            if torch.all(iou_mask == 0):
-                continue
-            else:
-                indices = torch.argmax(valid_objectness, dim=0)
-            embeds = image_class_embeds[indices]
-            query_embeds.append(embeds)
+            ious, indices = torch.max(iou, dim=0)
+            # filter for only iou > 0.4
+            iou_mask = ious > 0.4
+            indices = indices[iou_mask]
+            if not indices.numel() > 0:
+                continue
+
+            embeds = image_class_embeds[indices]
+            query_embeds.append(embeds)
if not query_embeds:
return None
-        query = torch.cat(query_embeds).mean(dim=0)
-        query /= torch.linalg.norm(query, ord=2) + 1e-6
+        query = torch.cat(query_embeds, dim=0)
return query
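The rewritten matching keeps one embedding per user-drawn box (the highest-IoU proposal, kept only above 0.4) and concatenates them instead of mean-pooling, so a class can carry several prototypes. A standalone sketch of that step, using torchvision's box_iou with invented corner-format boxes:

import torch
from torchvision.ops import box_iou

proposals = torch.tensor([[0.0, 0.0, 10.0, 10.0], [20.0, 20.0, 30.0, 30.0]])
query = torch.tensor([[1.0, 1.0, 9.0, 9.0]])  # one user-drawn box
iou = box_iou(proposals, query)               # (num_proposals, num_queries)
best_iou, best_idx = torch.max(iou, dim=0)    # best proposal per query box
matched = best_idx[best_iou > 0.4]            # their embeddings become the prototypes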

-    def infer_from_embed(self, image_hash: Hash, query_embeddings, confidence):
-        objectness, image_boxes, image_class_embeds, logit_shift, logit_scale = (
-            self.image_embed_cache[image_hash]
-        )
-        predicted_boxes = []
-        predicted_classes = []
-        predicted_scores = []
-        class_names = sorted(list(query_embeddings.keys()))
-        class_map = {class_name: i for i, class_name in enumerate(class_names)}
-        for class_name, embedding in query_embeddings.items():
-            if embedding is None:
-                continue
-            pred_logits = torch.einsum("sd,d->s", image_class_embeds, embedding)
-            pred_logits = (pred_logits + logit_shift) * logit_scale
-            prediction_scores = pred_logits.sigmoid()
-            score_mask = prediction_scores > confidence
-            predicted_boxes.append(image_boxes[score_mask, :])
-            scores = prediction_scores[score_mask]
-            predicted_scores.append(scores)
-            class_ind = class_map[class_name]
-            predicted_classes.append(class_ind * torch.ones_like(scores))
-
-        all_boxes = torch.cat(predicted_boxes, dim=0).float()
-        all_classes = torch.cat(predicted_classes, dim=0).float()
-        all_scores = torch.cat(predicted_scores, dim=0).float()
-        survival_indices = torchvision.ops.nms(to_corners(all_boxes), all_scores, 0.3)
-        pred_boxes = all_boxes[survival_indices].detach().cpu().numpy()
-        pred_classes = all_classes[survival_indices].detach().cpu().numpy()
-        pred_scores = all_scores[survival_indices].detach().cpu().numpy()
+    def infer_from_embed(
+        self,
+        image_hash: Hash,
+        query_embeddings: Dict[str, PosNegDictType],
+        confidence: float,
+    ) -> List[Dict]:
+        _, image_boxes, image_class_embeds, _, _ = self.image_embed_cache[image_hash]
+        class_map, class_names = make_class_map(query_embeddings)
+        all_predicted_boxes, all_predicted_classes, all_predicted_scores = [], [], []
+        for class_name, pos_neg_embedding_dict in query_embeddings.items():
+            boxes, classes, scores = get_class_preds_from_embeds(
+                pos_neg_embedding_dict,
+                image_class_embeds,
+                confidence,
+                image_boxes,
+                class_map,
+                class_name,
+            )
+
+            all_predicted_boxes.extend(boxes)
+            all_predicted_classes.extend(classes)
+            all_predicted_scores.extend(scores)
return [
{
"class_name": class_names[int(c)],
@@ -281,25 +369,15 @@ def infer_from_embed(self, image_hash: Hash, query_embeddings, confidence):
"h": float(h),
"confidence": float(score),
}
-            for c, (x, y, w, h), score in zip(pred_classes, pred_boxes, pred_scores)
+            for c, (x, y, w, h), score in zip(
+                all_predicted_classes, all_predicted_boxes, all_predicted_scores
+            )
]

-    def infer(self, image, training_data, confidence=0.99, **kwargs):
-        class_to_query_spec = defaultdict(lambda: defaultdict(list))
-        for train_image_dict in training_data:
-            boxes, train_image = train_image_dict["boxes"], train_image_dict["image"]
-            train_image = load_image_rgb(train_image)
-            image_hash = self.embed_image(train_image)
-            for box in boxes:
-                class_name = box["cls"]
-                coords = box["x"], box["y"], box["w"], box["h"]
-                coords = tuple([c / max(train_image.shape[:2]) for c in coords])
-                class_to_query_spec[class_name][image_hash].append(coords)
+    def infer(self, image: Any, training_data: Dict, confidence=0.99, **kwargs):
+        class_to_query_spec = self.make_class_box_query_dict(training_data)

-        my_class_to_embeddings_dict = dict()
-        for class_name, query_spec in class_to_query_spec.items():
-            class_embedding = self.get_query_embedding(query_spec)
-            my_class_to_embeddings_dict[class_name] = class_embedding
+        class_embeddings_dict = self.make_class_embeddings_dict(class_to_query_spec)

if not isinstance(image, list):
images = [image]
@@ -313,12 +391,42 @@ def infer(self, image, training_data, confidence=0.99, **kwargs):
image_sizes.append(image.shape[:2][::-1])
image_hash = self.embed_image(image)
result = self.infer_from_embed(
-                image_hash, my_class_to_embeddings_dict, confidence
+                image_hash, class_embeddings_dict, confidence
)
results.append(result)
return self.make_response(
-            results, image_sizes, sorted(list(my_class_to_embeddings_dict.keys()))
+            results, image_sizes, sorted(list(class_embeddings_dict.keys()))
)

def make_class_embeddings_dict(
        self, class_to_query_spec: Dict[Tuple[str, bool], Dict]
) -> Dict[str, PosNegDictType]:
class_embeddings_dict = defaultdict(
lambda: {"positive": None, "negative": None}
)
bool_to_literal = {True: "positive", False: "negative"}
for (class_name, positive), query_spec in class_to_query_spec.items():
class_embedding = self.get_query_embedding(query_spec)
class_embeddings_dict[class_name][
bool_to_literal[positive]
] = class_embedding

return class_embeddings_dict
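The resulting structure, sketched with invented values; None marks a polarity that received no example boxes:

# {
#     "post": {"positive": <tensor of shape (k, d)>, "negative": None},
# }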

def make_class_box_query_dict(self, training_data):
class_to_query_spec = defaultdict(lambda: defaultdict(list))
for train_image_dict in training_data:
boxes, train_image = train_image_dict["boxes"], train_image_dict["image"]
train_image = load_image_rgb(train_image)
image_hash = self.embed_image(train_image)
for box in boxes:
negative = box["negative"]
positive = not negative
class_name = box["cls"]
coords = box["x"], box["y"], box["w"], box["h"]
coords = tuple([c / max(train_image.shape[:2]) for c in coords])
class_to_query_spec[(class_name, positive)][image_hash].append(coords)
return class_to_query_spec
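Note the normalization above: coordinates are divided by the longer image side, which appears to match the square padding applied in preprocess_image. The resulting structure, with invented hashes and values (True keys positive examples):

# {
#     ("post", True):  {"1a2b...": [(0.35, 0.48, 0.06, 0.35)]},
#     ("post", False): {"9f8e...": [(0.60, 0.20, 0.08, 0.12)]},
# }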

def make_response(self, predictions, image_sizes, class_names):
responses = [
2 changes: 1 addition & 1 deletion tests/inference/models_predictions_tests/test_owlv2.py
@@ -12,7 +12,7 @@ def test_owlv2():
training_data=[
{
"image": image,
"boxes": [{"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post"}],
"boxes": [{"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post", "negative": False}],
}
],
visualize_predictions=True,