Skip to content

Commit

Permalink
modify detection node not to use srv to speed up
Browse files Browse the repository at this point in the history
  • Loading branch information
ojh6404 committed Mar 16, 2024
1 parent dc4365a commit 7cc44b8
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 53 deletions.
89 changes: 63 additions & 26 deletions tracking_ros/node_scripts/grounding_dino_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
import torchvision
import supervision as sv
import numpy as np
import rospy

from cv_bridge import CvBridge
Expand All @@ -15,8 +16,10 @@
from jsk_recognition_msgs.msg import ClassificationResult
from tracking_ros_utils.srv import SamPrompt, SamPromptRequest

from segment_anything.utils.amg import remove_small_regions
from tracking_ros.cfg import GroundingDINOConfig as ServerConfig
from model_config import GroundingDINOConfig
from model_config import GroundingDINOConfig, SAMConfig
from utils import overlay_davis

BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
Expand All @@ -29,6 +32,13 @@ def __init__(self):
self.gd_config = GroundingDINOConfig.from_rosparam()
self.predictor = self.gd_config.get_predictor()
self.get_mask = rospy.get_param("~get_mask", False)
if self.get_mask:
self.sam_config = SAMConfig.from_rosparam()
self.sam_predictor = self.sam_config.get_predictor()
self.refine_mask = rospy.get_param("~refine_mask", False)
if self.refine_mask:
self.area_threshold = rospy.get_param("~area_threshold", 400)
self.refine_mode = rospy.get_param("~refine_mode", "holes") # "holes" or "islands"

self.bridge = CvBridge()
self.pub_vis_img = self.advertise("~output/debug_image", Image, queue_size=1)
Expand Down Expand Up @@ -132,34 +142,61 @@ def callback(self, img_msg):
scores = detections.confidence.tolist()
labels_with_scores = [f"{label} {score:.2f}" for label, score in zip(labels, scores)]

self.visualization = self.image.copy()
self.segmentation = None
visualization = self.image.copy()
segmentation = None
if self.get_mask and len(detections.xyxy) > 0:
rospy.wait_for_service("/sam_node/process_prompt")
try:
prompt = SamPromptRequest()
prompt.image = img_msg
for xyxy in detections.xyxy:
rect = Rect()
rect.x = int(xyxy[0]) # x1
rect.y = int(xyxy[1]) # y1
rect.width = int(xyxy[2] - xyxy[0]) # x2 - x1
rect.height = int(xyxy[3] - xyxy[1]) # y2 - y1
prompt.boxes.append(rect)
prompt_response = rospy.ServiceProxy("/sam_node/process_prompt", SamPrompt)
res = prompt_response(prompt)
seg_msg, vis_img_msg = res.segmentation, res.segmentation_image
self.segmentation = self.bridge.imgmsg_to_cv2(seg_msg, desired_encoding="32SC1")
self.visualization = self.bridge.imgmsg_to_cv2(vis_img_msg, desired_encoding="rgb8")
except rospy.ServiceException as e:
rospy.logerr(f"Service call failed: {e}")
self.visualization = BOX_ANNOTATOR.annotate(scene=self.visualization, detections=detections)
self.visualization = LABEL_ANNOTATOR.annotate(
scene=self.visualization, detections=detections, labels=labels_with_scores
result_mask = None
for i, box in enumerate(detections.xyxy):
mask, logit = self.process_prompt(
points=None,
labels=None,
bbox=np.array([box[0], box[1], box[2], box[3]]),
multimask=False,
)
if result_mask is None:
result_mask = mask.astype(np.uint8)
else:
result_mask[mask] = i + 1
visualization = self.image.copy()
if result_mask is not None:
visualization = overlay_davis(visualization, result_mask)
segmentation = result_mask.astype(np.int32)
visualization = BOX_ANNOTATOR.annotate(scene=visualization, detections=detections)
visualization = LABEL_ANNOTATOR.annotate(
scene=visualization, detections=detections, labels=labels_with_scores
)
self.publish_result(
detections.xyxy, labels, scores, self.segmentation, self.visualization, img_msg.header.frame_id
self.publish_result(detections.xyxy, labels, scores, segmentation, visualization, img_msg.header.frame_id)

def process_prompt(
    self,
    points=None,
    bbox=None,
    labels=None,
    mask_input=None,
    multimask: bool = True,
):
    """Segment ``self.image`` for a single prompt and return the best mask.

    Args:
        points: Point prompt, forwarded to the predictor as ``point_coords``.
        bbox: Box prompt, forwarded to the predictor as ``box``.
        labels: Labels for ``points``, forwarded as ``point_labels``.
        mask_input: Optional mask prompt for the first prediction pass.
        multimask: Forwarded as ``multimask_output``.

    Returns:
        A ``(mask, logit)`` tuple: the candidate with the highest predicted
        score and the logits that belong to it. When ``self.refine_mask`` is
        set, the mask comes from a second pass seeded with the first pass's
        logits and is cleaned with ``remove_small_regions``.
    """
    # set_image() computes the image embedding, which depends only on the
    # image and not on the prompt. This method is called once per detected
    # box, so embed only when self.image has been replaced by a new frame.
    # NOTE(review): the identity check assumes every frame is a new array
    # object and that self.image is never modified in place - confirm.
    if getattr(self, "_sam_embedded_image", None) is not self.image:
        self.sam_predictor.set_image(self.image)
        self._sam_embedded_image = self.image

    masks, scores, logits = self.sam_predictor.predict(
        point_coords=points,
        point_labels=labels,
        box=bbox,
        mask_input=mask_input,  # TODO
        multimask_output=multimask,
    )  # one entry per candidate mask; scores ranks the candidates
    best = int(np.argmax(scores))
    mask, logit = masks[best], logits[best]

    if self.refine_mask:
        # Second pass: feed the best logits back in as the mask prompt.
        masks, scores, logits = self.sam_predictor.predict(
            point_coords=points,
            point_labels=labels,
            box=bbox,
            mask_input=logit[None, :, :],
            multimask_output=multimask,
        )
        best = int(np.argmax(scores))
        mask, logit = masks[best], logits[best]
        # Clean up small regions according to the configured mode
        # ("holes" or "islands") and area threshold.
        mask, _ = remove_small_regions(mask, self.area_threshold, mode=self.refine_mode)
    return mask, logit


if __name__ == "__main__":
Expand Down
90 changes: 63 additions & 27 deletions tracking_ros/node_scripts/yolo_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-

import supervision as sv
import numpy as np
import rospy

from cv_bridge import CvBridge
Expand All @@ -11,10 +12,11 @@
from jsk_recognition_msgs.msg import Rect, RectArray
from jsk_recognition_msgs.msg import Label, LabelArray
from jsk_recognition_msgs.msg import ClassificationResult
from tracking_ros_utils.srv import SamPrompt, SamPromptRequest

from segment_anything.utils.amg import remove_small_regions
from tracking_ros.cfg import YOLOConfig as ServerConfig
from model_config import YOLOConfig
from model_config import YOLOConfig, SAMConfig
from utils import overlay_davis

BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
Expand All @@ -27,6 +29,13 @@ def __init__(self):
self.predictor = self.yolo_config.get_predictor()
self.reconfigure_server = Server(ServerConfig, self.config_cb)
self.get_mask = rospy.get_param("~get_mask", False)
if self.get_mask:
self.sam_config = SAMConfig.from_rosparam()
self.sam_predictor = self.sam_config.get_predictor()
self.refine_mask = rospy.get_param("~refine_mask", False)
if self.refine_mask:
self.area_threshold = rospy.get_param("~area_threshold", 400)
self.refine_mode = rospy.get_param("~refine_mode", "holes") # "holes" or "islands"

self.bridge = CvBridge()
self.pub_vis_img = self.advertise("~output/debug_image", Image, queue_size=1)
Expand Down Expand Up @@ -112,34 +121,61 @@ def callback(self, img_msg):
scores = detections.confidence.tolist()
labels_with_scores = [f"{label} {score:.2f}" for label, score in zip(labels, scores)]

self.visualization = self.image.copy()
self.segmentation = None
visualization = self.image.copy()
segmentation = None
if self.get_mask and len(detections.xyxy) > 0:
rospy.wait_for_service("/sam_node/process_prompt")
try:
prompt = SamPromptRequest()
prompt.image = img_msg
for xyxy in detections.xyxy:
rect = Rect()
rect.x = int(xyxy[0]) # x1
rect.y = int(xyxy[1]) # y1
rect.width = int(xyxy[2] - xyxy[0]) # x2 - x1
rect.height = int(xyxy[3] - xyxy[1]) # y2 - y1
prompt.boxes.append(rect)
prompt_response = rospy.ServiceProxy("/sam_node/process_prompt", SamPrompt)
res = prompt_response(prompt)
seg_msg, vis_img_msg = res.segmentation, res.segmentation_image
self.segmentation = self.bridge.imgmsg_to_cv2(seg_msg, desired_encoding="32SC1")
self.visualization = self.bridge.imgmsg_to_cv2(vis_img_msg, desired_encoding="bgr8")
except rospy.ServiceException as e:
rospy.logerr(f"Service call failed: {e}")
self.visualization = BOX_ANNOTATOR.annotate(scene=self.visualization, detections=detections)
self.visualization = LABEL_ANNOTATOR.annotate(
scene=self.visualization, detections=detections, labels=labels_with_scores
result_mask = None
for i, box in enumerate(detections.xyxy):
mask, logit = self.process_prompt(
points=None,
labels=None,
bbox=np.array([box[0], box[1], box[2], box[3]]),
multimask=False,
)
if result_mask is None:
result_mask = mask.astype(np.uint8)
else:
result_mask[mask] = i + 1
visualization = self.image.copy()
if result_mask is not None:
visualization = overlay_davis(visualization, result_mask)
segmentation = result_mask.astype(np.int32)
visualization = BOX_ANNOTATOR.annotate(scene=visualization, detections=detections)
visualization = LABEL_ANNOTATOR.annotate(
scene=visualization, detections=detections, labels=labels_with_scores
)
self.publish_result(
detections.xyxy, labels, scores, self.segmentation, self.visualization, img_msg.header.frame_id
self.publish_result(detections.xyxy, labels, scores, segmentation, visualization, img_msg.header.frame_id)

def process_prompt(
    self,
    points=None,
    bbox=None,
    labels=None,
    mask_input=None,
    multimask: bool = True,
):
    """Segment ``self.image`` for a single prompt and return the best mask.

    Args:
        points: Point prompt, forwarded to the predictor as ``point_coords``.
        bbox: Box prompt, forwarded to the predictor as ``box``.
        labels: Labels for ``points``, forwarded as ``point_labels``.
        mask_input: Optional mask prompt for the first prediction pass.
        multimask: Forwarded as ``multimask_output``.

    Returns:
        A ``(mask, logit)`` tuple: the candidate with the highest predicted
        score and the logits that belong to it. When ``self.refine_mask`` is
        set, the mask comes from a second pass seeded with the first pass's
        logits and is cleaned with ``remove_small_regions``.
    """
    # set_image() computes the image embedding, which depends only on the
    # image and not on the prompt. This method is called once per detected
    # box, so embed only when self.image has been replaced by a new frame.
    # NOTE(review): the identity check assumes every frame is a new array
    # object and that self.image is never modified in place - confirm.
    if getattr(self, "_sam_embedded_image", None) is not self.image:
        self.sam_predictor.set_image(self.image)
        self._sam_embedded_image = self.image

    masks, scores, logits = self.sam_predictor.predict(
        point_coords=points,
        point_labels=labels,
        box=bbox,
        mask_input=mask_input,  # TODO
        multimask_output=multimask,
    )  # one entry per candidate mask; scores ranks the candidates
    best = int(np.argmax(scores))
    mask, logit = masks[best], logits[best]

    if self.refine_mask:
        # Second pass: feed the best logits back in as the mask prompt.
        masks, scores, logits = self.sam_predictor.predict(
            point_coords=points,
            point_labels=labels,
            box=bbox,
            mask_input=logit[None, :, :],
            multimask_output=multimask,
        )
        best = int(np.argmax(scores))
        mask, logit = masks[best], logits[best]
        # Clean up small regions according to the configured mode
        # ("holes" or "islands") and area threshold.
        mask, _ = remove_small_regions(mask, self.area_threshold, mode=self.refine_mode)
    return mask, logit


if __name__ == "__main__":
Expand Down

0 comments on commit 7cc44b8

Please sign in to comment.