Skip to content

Commit

Permalink
fix for grounding dino when using dynamic reconfigure
Browse files (browse the repository at this point in the history)
  • Loading branch information
ojh6404 committed Mar 12, 2024
1 parent 05650b5 commit 4e2badf
Showing 1 changed file with 56 additions and 52 deletions.
108 changes: 56 additions & 52 deletions tracking_ros/node_scripts/grounding_dino_node.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -51,10 +51,13 @@ def unsubscribe(self):
self.sub_image.unregister()

def config_cb(self, config, level):
self.classes = [_class.strip() for _class in config.classes.split(";")]
self.detect_flag = False
self.classes = [_class.strip() for _class in config.classes.split(";") if _class.strip()]
rospy.loginfo(f"Detecting Classes: {self.classes}")
self.box_threshold = config.box_threshold
self.text_threshold = config.text_threshold
self.nms_threshold = config.nms_threshold
self.detect_flag = True
return config

def publish_result(self, boxes, label_names, scores, mask, vis, frame_id):
Expand Down Expand Up @@ -102,60 +105,61 @@ def publish_result(self, boxes, label_names, scores, mask, vis, frame_id):
self.pub_seg.publish(seg_msg)

def callback(self, img_msg):
self.image = self.bridge.imgmsg_to_cv2(img_msg, desired_encoding="rgb8")
detections = self.predictor.predict_with_classes(
image=self.image,
classes=self.classes,
box_threshold=self.box_threshold,
text_threshold=self.text_threshold,
)
if self.detect_flag:
self.image = self.bridge.imgmsg_to_cv2(img_msg, desired_encoding="rgb8")
detections = self.predictor.predict_with_classes(
image=self.image,
classes=self.classes,
box_threshold=self.box_threshold,
text_threshold=self.text_threshold,
)

nms_idx = (
torchvision.ops.nms(
torch.from_numpy(detections.xyxy),
torch.from_numpy(detections.confidence),
self.nms_threshold,
nms_idx = (
torchvision.ops.nms(
torch.from_numpy(detections.xyxy),
torch.from_numpy(detections.confidence),
self.nms_threshold,
)
.numpy()
.tolist()
)
.numpy()
.tolist()
)

detections.xyxy = detections.xyxy[nms_idx]
detections.confidence = detections.confidence[nms_idx]
detections.class_id = detections.class_id[nms_idx]

labels = [self.classes[cls_id] for cls_id in detections.class_id]
scores = detections.confidence.tolist()
labels_with_scores = [f"{label} {score:.2f}" for label, score in zip(labels, scores)]

self.visualization = self.image.copy()
self.segmentation = None
if self.get_mask and len(detections.xyxy) > 0:
rospy.wait_for_service("/sam_node/process_prompt")
try:
prompt = SamPromptRequest()
prompt.image = img_msg
for xyxy in detections.xyxy:
rect = Rect()
rect.x = int(xyxy[0]) # x1
rect.y = int(xyxy[1]) # y1
rect.width = int(xyxy[2] - xyxy[0]) # x2 - x1
rect.height = int(xyxy[3] - xyxy[1]) # y2 - y1
prompt.boxes.append(rect)
prompt_response = rospy.ServiceProxy("/sam_node/process_prompt", SamPrompt)
res = prompt_response(prompt)
seg_msg, vis_img_msg = res.segmentation, res.segmentation_image
self.segmentation = self.bridge.imgmsg_to_cv2(seg_msg, desired_encoding="32SC1")
self.visualization = self.bridge.imgmsg_to_cv2(vis_img_msg, desired_encoding="rgb8")
except rospy.ServiceException as e:
rospy.logerr(f"Service call failed: {e}")
self.visualization = BOX_ANNOTATOR.annotate(scene=self.visualization, detections=detections)
self.visualization = LABEL_ANNOTATOR.annotate(
scene=self.visualization, detections=detections, labels=labels_with_scores
)
self.publish_result(
detections.xyxy, labels, scores, self.segmentation, self.visualization, img_msg.header.frame_id
)
detections.xyxy = detections.xyxy[nms_idx]
detections.confidence = detections.confidence[nms_idx]
detections.class_id = detections.class_id[nms_idx]

labels = [self.classes[cls_id] for cls_id in detections.class_id]
scores = detections.confidence.tolist()
labels_with_scores = [f"{label} {score:.2f}" for label, score in zip(labels, scores)]

self.visualization = self.image.copy()
self.segmentation = None
if self.get_mask and len(detections.xyxy) > 0:
rospy.wait_for_service("/sam_node/process_prompt")
try:
prompt = SamPromptRequest()
prompt.image = img_msg
for xyxy in detections.xyxy:
rect = Rect()
rect.x = int(xyxy[0]) # x1
rect.y = int(xyxy[1]) # y1
rect.width = int(xyxy[2] - xyxy[0]) # x2 - x1
rect.height = int(xyxy[3] - xyxy[1]) # y2 - y1
prompt.boxes.append(rect)
prompt_response = rospy.ServiceProxy("/sam_node/process_prompt", SamPrompt)
res = prompt_response(prompt)
seg_msg, vis_img_msg = res.segmentation, res.segmentation_image
self.segmentation = self.bridge.imgmsg_to_cv2(seg_msg, desired_encoding="32SC1")
self.visualization = self.bridge.imgmsg_to_cv2(vis_img_msg, desired_encoding="rgb8")
except rospy.ServiceException as e:
rospy.logerr(f"Service call failed: {e}")
self.visualization = BOX_ANNOTATOR.annotate(scene=self.visualization, detections=detections)
self.visualization = LABEL_ANNOTATOR.annotate(
scene=self.visualization, detections=detections, labels=labels_with_scores
)
self.publish_result(
detections.xyxy, labels, scores, self.segmentation, self.visualization, img_msg.header.frame_id
)


if __name__ == "__main__":
Expand Down

0 comments on commit 4e2badf

Please sign in to comment.