# Post-patch form of two methods from
# deep_vision_ros/node_scripts/deva_node.py.  They belong to the node class
# whose header (and __init__/subscribe/config_cb bodies) lie outside this
# excerpt, so they are written here at top level with an explicit ``self``.


def publish_result(self, boxes, label_names, scores, mask, vis, frame_id):
    """Publish the outputs computed for one frame.

    Each of ``label_names``, ``boxes``, ``vis`` and ``mask`` may be ``None``;
    the corresponding message is then skipped.  All messages published by a
    single call share one time stamp.

    Reads ``self.classes`` and ``self.gd_config.model_name`` and publishes on
    ``self.pub_labels``, ``self.pub_class``, ``self.pub_rects``,
    ``self.pub_vis_img`` and ``self.pub_seg``.

    Args:
        boxes: Iterable of boxes indexed as ``(x1, y1, x2, y2)``.
        label_names: Class name of each detection; each name must be an
            element of ``self.classes``.
        scores: Per-detection confidences stored as ``label_proba``.
            ``None`` yields an empty probability list.
        mask: Segmentation image, converted with the ``"32SC1"`` encoding.
        vis: Debug image, converted with the ``"rgb8"`` encoding.
        frame_id: ``header.frame_id`` written to every outgoing message.

    Raises:
        ValueError: If a name in ``label_names`` is not in ``self.classes``.
    """
    # Sample the clock once: per-message Time.now() calls give every output
    # of the same frame a different stamp, which defeats exact-time
    # synchronisation of these topics downstream.
    stamp = rospy.Time.now()

    if label_names is not None:
        label_array = LabelArray()
        # Label ids are 1-based positions in this frame's detection list.
        label_array.labels = [
            Label(id=i + 1, name=name) for i, name in enumerate(label_names)
        ]
        label_array.header.stamp = stamp
        label_array.header.frame_id = frame_id
        self.pub_labels.publish(label_array)

        # Plain floats keep the message independent of the array type the
        # detector returns; a missing score vector becomes an empty list
        # instead of a None that cannot be serialised.
        label_proba = [] if scores is None else [float(s) for s in scores]
        class_result = ClassificationResult(
            header=label_array.header,
            classifier=self.gd_config.model_name,
            target_names=self.classes,
            labels=[self.classes.index(name) for name in label_names],
            label_names=label_names,
            label_proba=label_proba,
        )
        self.pub_class.publish(class_result)

    if boxes is not None:
        # Corner format (x1, y1, x2, y2) -> origin plus width/height.
        rects = [
            Rect(
                x=int(box[0]),
                y=int(box[1]),
                width=int(box[2] - box[0]),
                height=int(box[3] - box[1]),
            )
            for box in boxes
        ]
        rect_array = RectArray(rects=rects)
        rect_array.header.stamp = stamp
        rect_array.header.frame_id = frame_id
        self.pub_rects.publish(rect_array)

    if vis is not None:
        vis_img_msg = self.bridge.cv2_to_imgmsg(vis, encoding="rgb8")
        vis_img_msg.header.stamp = stamp
        vis_img_msg.header.frame_id = frame_id
        self.pub_vis_img.publish(vis_img_msg)

    if mask is not None:
        seg_msg = self.bridge.cv2_to_imgmsg(mask, encoding="32SC1")
        seg_msg.header.stamp = stamp
        seg_msg.header.frame_id = frame_id
        self.pub_seg.publish(seg_msg)


def callback(self, img_msg):
    """Run the model on one incoming image and publish its outputs.

    Does nothing while ``self.track_flag`` is false.

    Args:
        img_msg: Incoming image message; decoded as ``"rgb8"``.  Its
            ``header.frame_id`` is copied onto every published message.
    """
    if not self.track_flag:
        return

    image = self.bridge.imgmsg_to_cv2(img_msg, desired_encoding="rgb8")
    detections, visualization, segmentation = self.deva_model.predict(
        image, self.sam_model, self.gd_model
    )

    # NOTE(review): ``detections`` looks like a supervision-style Detections
    # object, whose ``class_id`` may be None -- confirm against
    # deva_model.predict.  publish_result already treats None label_names as
    # "skip labels", so forward None rather than iterating over it.
    class_ids = detections.class_id
    if class_ids is None:
        label_names = None
    else:
        label_names = [self.classes[class_id] for class_id in class_ids]

    self.publish_result(
        detections.xyxy,
        label_names,
        detections.confidence,
        segmentation,
        visualization,
        img_msg.header.frame_id,
    )