Skip to content

Commit

Permalink
tweak for cutie track with prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
ojh6404 committed Apr 14, 2024
1 parent b9e6243 commit d1f3309
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 10 deletions.
50 changes: 40 additions & 10 deletions tracking_ros/node_scripts/cutie_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from cv_bridge import CvBridge
from sensor_msgs.msg import Image
from std_srvs.srv import Empty, EmptyResponse
from tracking_ros_utils.srv import CutiePrompt, CutiePromptResponse

from model_config import CutieConfig
from utils import overlay_davis
Expand All @@ -22,7 +23,8 @@ def __init__(self):
super(CutieNode, self).__init__()
self.with_bbox = rospy.get_param("~with_bbox", False)
self.bridge = CvBridge()
self.initialize()
image, mask = self.get_oneshot_prompt()
self.track_flag = self.initialize(image, mask)

self.sub_image = rospy.Subscriber(
"~input_image",
Expand All @@ -37,21 +39,49 @@ def __init__(self):
# reset tracking service
self.reset_service = rospy.Service("~reset", Empty, self.reset_callback)

self.process_prompt_service = rospy.Service(
"~process_prompt",
CutiePrompt,
self.prompt_service_callback,
)

def get_oneshot_prompt(self):
try:
# oneshot subscribe initial image and segmentation
input_seg_msg = rospy.wait_for_message("~input_segmentation", Image, timeout=5)
input_img_msg = rospy.wait_for_message("~input_image", Image, timeout=5)
mask = self.bridge.imgmsg_to_cv2(input_seg_msg, desired_encoding="32SC1")
image = self.bridge.imgmsg_to_cv2(input_img_msg, desired_encoding="rgb8")
return image, mask
except rospy.ROSException:
rospy.logwarn("No message received in 5 seconds")
return None, None

def prompt_service_callback(self, req):
rospy.loginfo("Processing prompt and resetting Cutie tracker")
self.track_flag = False
input_seg_msg = req.segmentation
input_img_msg = req.image
mask = self.bridge.imgmsg_to_cv2(input_seg_msg, desired_encoding="32SC1")
image = self.bridge.imgmsg_to_cv2(input_img_msg, desired_encoding="rgb8")
self.track_flag = self.initialize(image, mask)
return CutiePromptResponse(result=True)

def reset_callback(self, req):
rospy.loginfo("Resetting Cutie tracker")
self.initialize()
self.track_flag = False
image, mask = self.get_oneshot_prompt()
self.initialize(image, mask)
self.track_flag = self.initialize(image, mask)
return EmptyResponse()

@torch.inference_mode()
def initialize(self):
self.track_flag = False
def initialize(self, image, mask):
if image is None or mask is None:
return False

self.cutie_config = CutieConfig.from_rosparam()
self.predictor = self.cutie_config.get_predictor()
# oneshot subscribe initial image and segmentation
input_seg_msg = rospy.wait_for_message("~input_segmentation", Image)
mask = self.bridge.imgmsg_to_cv2(input_seg_msg, desired_encoding="32SC1")
input_img_msg = rospy.wait_for_message("~input_image", Image)
image = self.bridge.imgmsg_to_cv2(input_img_msg, desired_encoding="rgb8")

# initialize the model with the mask
with torch.cuda.amp.autocast(enabled=True):
Expand All @@ -70,7 +100,7 @@ def initialize(self):
)
# the background mask is not fed into the model
self.mask = self.predictor.step(image_torch, mask_torch[1:], idx_mask=False)
self.track_flag = True
return True

def publish_result(self, mask, vis, frame_id):
if mask is not None:
Expand Down
1 change: 1 addition & 0 deletions tracking_ros_utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ find_package(catkin REQUIRED
add_service_files(
FILES
SamPrompt.srv
CutiePrompt.srv
)
generate_messages(
DEPENDENCIES
Expand Down
6 changes: 6 additions & 0 deletions tracking_ros_utils/srv/CutiePrompt.srv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Request
sensor_msgs/Image image # prompt image
sensor_msgs/Image segmentation # prompt segmentation
---
# Response
bool result # tracker set result

0 comments on commit d1f3309

Please sign in to comment.