Skip to content

Commit

Permalink
refactor code to use catkin python package and fix formatting param
Browse files Browse the repository at this point in the history
  • Loading branch information
ojh6404 committed Jul 24, 2024
1 parent be80e11 commit f87758c
Show file tree
Hide file tree
Showing 16 changed files with 747 additions and 681 deletions.
5 changes: 4 additions & 1 deletion tracking_ros/launch/grounding_detection.launch
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
<arg name="with_bbox" default="true" /> <!-- set to false for faster processing; true draws bounding boxes for visualization -->
<arg name="refine_mask" default="true" /> <!-- when true, the mask is predicted twice to refine it -->

<arg name="_get_mask" value="true" if="$(arg track)"/>
<arg name="_get_mask" value="$(arg get_mask)" unless="$(arg track)"/>

<!-- grounding_dino_node -->
<node name="grounding_dino_node"
pkg="tracking_ros" type="grounding_dino_node.py"
output="screen" >
<remap from="~input_image" to="$(arg input_image)" />
<rosparam subst_value="true" >
device: $(arg device)
get_mask: $(arg get_mask)
get_mask: $(arg _get_mask)
model_type: $(arg model_type)
refine_mask: $(arg refine_mask)
</rosparam>
Expand Down
5 changes: 4 additions & 1 deletion tracking_ros/launch/yolo_detection.launch
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@
<arg name="with_bbox" default="true" /> <!-- set to false for faster processing; true draws bounding boxes for visualization -->
<arg name="refine_mask" default="true" /> <!-- when true, the mask is predicted twice to refine it -->

<arg name="_get_mask" value="true" if="$(arg track)"/>
<arg name="_get_mask" value="$(arg get_mask)" unless="$(arg track)"/>

<!-- yolo_world_node -->
<node name="yolo_node"
pkg="tracking_ros" type="yolo_node.py"
output="screen" >
<remap from="~input_image" to="$(arg input_image)" />
<rosparam subst_value="true" >
device: $(arg device)
get_mask: $(arg get_mask)
get_mask: $(arg _get_mask)
model_id: $(arg model_id)
model_type: $(arg model_type)
refine_mask: $(arg refine_mask)
Expand Down
75 changes: 11 additions & 64 deletions tracking_ros/node_scripts/cutie_node.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import torch
import torch.nn.functional as F
import supervision as sv

import rospy
from cv_bridge import CvBridge
from sensor_msgs.msg import Image
from std_srvs.srv import Empty, EmptyResponse
from tracking_ros_utils.srv import CutiePrompt, CutiePromptResponse

from tracking_ros.model_config import CutieConfig
from tracking_ros.utils import overlay_davis

BOX_ANNOTATOR = sv.BoxAnnotator()
from tracking_ros.model_wrapper import CutieModel


class CutieNode(object): # should not be ConnectionBasedNode cause Cutie tracker needs continuous input
Expand Down Expand Up @@ -48,8 +41,8 @@ def __init__(self):
def get_oneshot_prompt(self):
try:
# oneshot subscribe initial image and segmentation
input_seg_msg = rospy.wait_for_message("~input_segmentation", Image, timeout=5)
input_img_msg = rospy.wait_for_message("~input_image", Image, timeout=5)
input_seg_msg = rospy.wait_for_message("~input_segmentation", Image)
input_img_msg = rospy.wait_for_message("~input_image", Image)
mask = self.bridge.imgmsg_to_cv2(input_seg_msg, desired_encoding="32SC1")
image = self.bridge.imgmsg_to_cv2(input_img_msg, desired_encoding="rgb8")
return image, mask
Expand All @@ -68,43 +61,24 @@ def prompt_service_callback(self, req):
return CutiePromptResponse(result=True)

def reset_callback(self, req):
rospy.loginfo("Resetting Cutie tracker")
self.track_flag = False
rospy.loginfo("Resetting Cutie tracker")
image, mask = self.get_oneshot_prompt()
self.initialize(image, mask)
self.track_flag = self.initialize(image, mask)
return EmptyResponse()

@torch.inference_mode()
def initialize(self, image, mask):
if image is None or mask is None:
return False

self.cutie_config = CutieConfig.from_rosparam()
self.predictor = self.cutie_config.get_predictor()

# initialize the model with the mask
with torch.cuda.amp.autocast(enabled=True):
image_torch = (
torch.from_numpy(image.transpose(2, 0, 1)).float().to(self.cutie_config.device, non_blocking=True) / 255
)
# initialize with the mask
mask_torch = (
F.one_hot(
torch.from_numpy(mask).long(),
num_classes=len(np.unique(mask)),
)
.permute(2, 0, 1)
.float()
.to(self.cutie_config.device)
)
# the background mask is not fed into the model
self.mask = self.predictor.step(image_torch, mask_torch[1:], idx_mask=False)
self.config = CutieConfig.from_rosparam()
self.model = CutieModel(self.config)
self.model.set_model(image, mask)
return True

def publish_result(self, mask, vis, frame_id):
if mask is not None:
seg_msg = self.bridge.cv2_to_imgmsg(mask.astype(np.int32), encoding="32SC1")
seg_msg = self.bridge.cv2_to_imgmsg(mask, encoding="32SC1")
seg_msg.header.stamp = rospy.Time.now()
seg_msg.header.frame_id = frame_id
self.pub_segmentation_img.publish(seg_msg)
Expand All @@ -114,38 +88,11 @@ def publish_result(self, mask, vis, frame_id):
vis_img_msg.header.frame_id = frame_id
self.pub_vis_img.publish(vis_img_msg)

@torch.inference_mode()
def callback(self, img_msg):
if self.track_flag:
self.image = self.bridge.imgmsg_to_cv2(img_msg, desired_encoding="rgb8")
with torch.cuda.amp.autocast(enabled=True):
image_torch = (
torch.from_numpy(self.image.transpose(2, 0, 1))
.float()
.to(self.cutie_config.device, non_blocking=True)
/ 255
)
prediction = self.predictor.step(image_torch)
self.mask = torch.max(prediction, dim=0).indices.cpu().numpy().astype(np.uint8)
self.visualization = overlay_davis(self.image.copy(), self.mask)
if self.with_bbox and len(np.unique(self.mask)) > 1:
masks = []
for i in range(1, len(np.unique(self.mask))):
masks.append((self.mask == i).astype(np.uint8))

self.masks = np.stack(masks, axis=0)
xyxy = sv.mask_to_xyxy(self.masks) # [N, 4]
detections = sv.Detections(
xyxy=xyxy,
mask=self.masks,
tracker_id=np.arange(len(xyxy)),
)
self.visualization = BOX_ANNOTATOR.annotate(
scene=self.visualization,
detections=detections,
labels=[f"ObjectID : {i}" for i in range(len(xyxy))],
)
self.publish_result(self.mask, self.visualization, img_msg.header.frame_id)
image = self.bridge.imgmsg_to_cv2(img_msg, desired_encoding="rgb8")
segmentation, visualization = self.model.predict(image)
self.publish_result(segmentation, visualization, img_msg.header.frame_id)


if __name__ == "__main__":
Expand Down
147 changes: 19 additions & 128 deletions tracking_ros/node_scripts/deva_node.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import cv2
import torch
import torchvision
import supervision as sv

import numpy as np
import rospy
from cv_bridge import CvBridge
from dynamic_reconfigure.server import Server
Expand All @@ -17,30 +13,25 @@
from jsk_recognition_msgs.msg import ClassificationResult
from jsk_recognition_msgs.msg import Label, LabelArray

from deva.dataset.utils import im_normalization
from deva.inference.object_info import ObjectInfo

from tracking_ros.cfg import GroundingDINOConfig as ServerConfig
from tracking_ros.model_config import SAMConfig, GroundingDINOConfig, DEVAConfig
from tracking_ros.utils import overlay_davis

torch.autograd.set_grad_enabled(False)

BOX_ANNOTATOR = sv.BoxAnnotator()
from tracking_ros.model_wrapper import GroundingDINOModel, SAMModel, DEVAModel


class DevaNode(ConnectionBasedTransport):
def __init__(self):
super(DevaNode, self).__init__()
self.sam_config = SAMConfig.from_rosparam()
self.gd_config = GroundingDINOConfig.from_rosparam()
self.gd_model = GroundingDINOModel(self.gd_config)
self.deva_config = DEVAConfig.from_rosparam()
self.deva_model = DEVAModel(self.deva_config)
self.reconfigure_server = Server(ServerConfig, self.config_cb)
self.with_bbox = rospy.get_param("~with_bbox", True)

self.bridge = CvBridge()
self.pub_segmentation_img = self.advertise("~output/segmentation", Image, queue_size=1)
self.pub_vis_img = self.advertise("~output/segmentation_image", Image, queue_size=1)
self.pub_vis_img = self.advertise("~output/debug_image", Image, queue_size=1)
self.pub_rects = self.advertise("~output/rects", RectArray, queue_size=1)
# self.pub_labels = self.advertise("~output/labels", LabelArray, queue_size=1)
# self.pub_class = self.advertise("~output/class", ClassificationResult, queue_size=1)
Expand All @@ -59,18 +50,21 @@ def unsubscribe(self):

def config_cb(self, config, level):
self.track_flag = False
self.sam_predictor = self.sam_config.get_predictor()
self.gd_predictor = self.gd_config.get_predictor()
self.deva_predictor, self.cfg = self.deva_config.get_predictor() # TODO integrate cfg into DEVAConfig
self.sam_config = SAMConfig.from_rosparam()
self.sam_model = SAMModel(self.sam_config)
self.sam_model.set_model()
self.classes = [_class.strip() for _class in config.classes.split(";")]
self.cfg["prompt"] = ".".join(self.classes)
self.cnt = 0
self.gd_model.set_model(
self.classes, config.box_threshold, config.text_threshold, config.nms_threshold
)
self.deva_model.set_model()
rospy.loginfo(f"Detecting Classes: {self.classes}")
self.track_flag = True
return config

def publish_result(self, mask, vis, frame_id):
if mask is not None:
seg_msg = self.bridge.cv2_to_imgmsg(mask.astype(np.int32), encoding="32SC1")
seg_msg = self.bridge.cv2_to_imgmsg(mask, encoding="32SC1")
seg_msg.header.stamp = rospy.Time.now()
seg_msg.header.frame_id = frame_id
self.pub_segmentation_img.publish(seg_msg)
Expand All @@ -80,116 +74,13 @@ def publish_result(self, mask, vis, frame_id):
vis_img_msg.header.frame_id = frame_id
self.pub_vis_img.publish(vis_img_msg)

@torch.inference_mode()
def callback(self, img_msg):
if self.track_flag:
self.image = self.bridge.imgmsg_to_cv2(img_msg, desired_encoding="rgb8")
with torch.cuda.amp.autocast(enabled=self.cfg["amp"]):
torch_image = im_normalization(torch.from_numpy(self.image).permute(2, 0, 1).float() / 255)
deva_input = torch_image.to(self.deva_config.device)
if self.cnt % self.cfg["detection_every"] == 0: # object detection query
self.sam_predictor.set_image(self.image, image_format="RGB")
# detect objects with GroundingDINO
detections = self.gd_predictor.predict_with_classes(
image=cv2.cvtColor(self.image, cv2.COLOR_RGB2BGR),
classes=self.classes,
box_threshold=self.cfg["DINO_THRESHOLD"],
text_threshold=self.cfg["DINO_THRESHOLD"],
)
nms_idx = (
torchvision.ops.nms(
torch.from_numpy(detections.xyxy),
torch.from_numpy(detections.confidence),
self.cfg["DINO_NMS_THRESHOLD"],
)
.numpy()
.tolist()
)
detections.xyxy = detections.xyxy[nms_idx]
detections.confidence = detections.confidence[nms_idx]
detections.class_id = detections.class_id[nms_idx]
# segment objects with SAM
result_masks = []
for box in detections.xyxy:
masks, scores, _ = self.sam_predictor.predict(box=box, multimask_output=True)
index = np.argmax(scores)
result_masks.append(masks[index])
detections.mask = np.array(result_masks)
incorporate_mask = torch.zeros(
self.image.shape[:2], dtype=torch.int64, device=self.gd_predictor.device
)
curr_id = 1
segments_info = []
# sort by descending area to preserve the smallest object
for i in np.flip(np.argsort(detections.area)):
mask = detections.mask[i]
confidence = detections.confidence[i]
class_id = detections.class_id[i]
mask = torch.from_numpy(mask.astype(np.float32))
mask = (mask > 0.5).float()
if mask.sum() > 0:
incorporate_mask[mask > 0] = curr_id
segments_info.append(ObjectInfo(id=curr_id, category_id=class_id, score=confidence))
curr_id += 1
prob = self.deva_predictor.incorporate_detection(deva_input, incorporate_mask, segments_info)
self.cnt = 1
else:
prob = self.deva_predictor.step(deva_input, None, None)
self.cnt += 1
self.mask = torch.argmax(prob, dim=0).cpu().numpy() # (H, W)
object_num = len(np.unique(self.mask)) - 1
if self.with_bbox and object_num > 0:
masks = []
for i in np.unique(self.mask)[1:]:
mask = (self.mask == i).astype(np.uint8)
masks.append(mask)
self.masks = np.stack(masks, axis=0) # (N, H, W)
xyxy = sv.mask_to_xyxy(self.masks)
object_ids = np.unique(self.mask)[1:] # without background
detections = sv.Detections(
xyxy=xyxy,
mask=self.masks,
class_id=object_ids,
)
painted_image = overlay_davis(self.image.copy(), self.mask)
# TODO: convert labels to class names; this needs a workaround because object IDs and class IDs are not consistent between tracking and detection
self.visualization = BOX_ANNOTATOR.annotate(
scene=painted_image,
detections=detections,
labels=[f"ObjectID: {obj_id}" for obj_id in object_ids],
)

rects = []
for box in xyxy:
rect = Rect()
rect.x = int((box[0] + box[2]) / 2)
rect.y = int((box[1] + box[3]) / 2)
rect.width = int(box[2] - box[0])
rect.height = int(box[3] - box[1])
rects.append(rect)
rect_array = RectArray(rects=rects)
rect_array.header.stamp = rospy.Time.now()
rect_array.header.frame_id = img_msg.header.frame_id
self.pub_rects.publish(rect_array)

# label_names = [self.classes[cls_id] for cls_id in detections.class_id]
# label_array = LabelArray()
# label_array.labels = [Label(id=i + 1, name=name) for i, name in enumerate(label_names)]
# label_array.header.stamp = rospy.Time.now()
# label_array.header.frame_id = img_msg.header.frame_id
# class_result = ClassificationResult(
# header=label_array.header,
# classifier=self.gd_config.model_name,
# target_names=self.classes,
# labels=[self.classes.index(name) for name in label_names],
# label_names=label_names,
# label_proba=detections.confidence.tolist(),
# )
# self.pub_labels.publish(label_array)
# self.pub_class.publish(class_result)
else:
self.visualization = overlay_davis(self.image.copy(), self.mask)
self.publish_result(self.mask, self.visualization, img_msg.header.frame_id)
image = self.bridge.imgmsg_to_cv2(img_msg, desired_encoding="rgb8")
detections, visualization, segmentation = self.deva_model.predict(
image, self.sam_model, self.gd_model
)
self.publish_result(segmentation, visualization, img_msg.header.frame_id)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit f87758c

Please sign in to comment.