Skip to content

Commit

Permalink
fix deva node
Browse files Browse the repository at this point in the history
  • Loading branch information
ojh6404 committed Jul 29, 2024
1 parent 179faf9 commit 228b31d
Showing 1 changed file with 50 additions and 10 deletions.
60 changes: 50 additions & 10 deletions deep_vision_ros/node_scripts/deva_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ def __init__(self):
self.with_bbox = rospy.get_param("~with_bbox", True)

self.bridge = CvBridge()
self.pub_segmentation_img = self.advertise("~output/segmentation", Image, queue_size=1)
self.pub_seg = self.advertise("~output/segmentation", Image, queue_size=1)
self.pub_vis_img = self.advertise("~output/debug_image", Image, queue_size=1)
self.pub_rects = self.advertise("~output/rects", RectArray, queue_size=1)
# self.pub_labels = self.advertise("~output/labels", LabelArray, queue_size=1)
# self.pub_class = self.advertise("~output/class", ClassificationResult, queue_size=1)
self.pub_labels = self.advertise("~output/labels", LabelArray, queue_size=1)
self.pub_class = self.advertise("~output/class", ClassificationResult, queue_size=1)

def subscribe(self):
self.sub_image = rospy.Subscriber(
Expand Down Expand Up @@ -62,25 +62,65 @@ def config_cb(self, config, level):
self.track_flag = True
return config

def publish_result(self, mask, vis, frame_id):
if mask is not None:
seg_msg = self.bridge.cv2_to_imgmsg(mask, encoding="32SC1")
seg_msg.header.stamp = rospy.Time.now()
seg_msg.header.frame_id = frame_id
self.pub_segmentation_img.publish(seg_msg)

def publish_result(self, boxes, label_names, scores, mask, vis, frame_id):
if label_names is not None:
label_array = LabelArray()
label_array.labels = [Label(id=i + 1, name=name) for i, name in enumerate(label_names)]
label_array.header.stamp = rospy.Time.now()
label_array.header.frame_id = frame_id
self.pub_labels.publish(label_array)

class_result = ClassificationResult(
header=label_array.header,
classifier=self.gd_config.model_name,
target_names=self.classes,
labels=[self.classes.index(name) for name in label_names],
label_names=label_names,
label_proba=scores,
)
self.pub_class.publish(class_result)

if boxes is not None:
rects = []
for box in boxes:
rect = Rect()
rect.x = int(box[0]) # x1
rect.y = int(box[1]) # y1
rect.width = int(box[2] - box[0]) # x2 - x1
rect.height = int(box[3] - box[1]) # y2 - y1
rects.append(rect)
rect_array = RectArray(rects=rects)
rect_array.header.stamp = rospy.Time.now()
rect_array.header.frame_id = frame_id
self.pub_rects.publish(rect_array)

if vis is not None:
vis_img_msg = self.bridge.cv2_to_imgmsg(vis, encoding="rgb8")
vis_img_msg.header.stamp = rospy.Time.now()
vis_img_msg.header.frame_id = frame_id
self.pub_vis_img.publish(vis_img_msg)

if mask is not None:
seg_msg = self.bridge.cv2_to_imgmsg(mask, encoding="32SC1")
seg_msg.header.stamp = rospy.Time.now()
seg_msg.header.frame_id = frame_id
self.pub_seg.publish(seg_msg)

def callback(self, img_msg):
if self.track_flag:
image = self.bridge.imgmsg_to_cv2(img_msg, desired_encoding="rgb8")
detections, visualization, segmentation = self.deva_model.predict(
image, self.sam_model, self.gd_model
)
self.publish_result(segmentation, visualization, img_msg.header.frame_id)
self.publish_result(
detections.xyxy,
[self.classes[class_id] for class_id in detections.class_id],
detections.confidence,
segmentation,
visualization,
img_msg.header.frame_id,
)


if __name__ == "__main__":
Expand Down

0 comments on commit 228b31d

Please sign in to comment.