diff --git a/tests/test_feature_extractor.py b/tests/test_feature_extractor.py index 37b7f6e9..e98dace2 100644 --- a/tests/test_feature_extractor.py +++ b/tests/test_feature_extractor.py @@ -46,6 +46,7 @@ def test_feature_extractor(): # Store results to test directory img = get_img_from_fig(fig) img.save(join(outpath, f"forest_clean_graph_{seg_type}_{feat_type}.png")) + plt.close() if __name__ == "__main__": diff --git a/wild_visual_navigation/cfg/experiment_params.py b/wild_visual_navigation/cfg/experiment_params.py index dc3c62bc..92b58a54 100644 --- a/wild_visual_navigation/cfg/experiment_params.py +++ b/wild_visual_navigation/cfg/experiment_params.py @@ -102,7 +102,7 @@ class ModelParams: @dataclass class SimpleMlpCfgParams: - input_size: int = 384 # 90 for stego, 384 for dino + input_size: int = 90 # 90 for stego, 384 for dino hidden_sizes: List[int] = field(default_factory=lambda: [256, 32, 1]) reconstruction: bool = True diff --git a/wild_visual_navigation/cfg/ros_params.py b/wild_visual_navigation/cfg/ros_params.py index 5b9b2578..d0734705 100644 --- a/wild_visual_navigation/cfg/ros_params.py +++ b/wild_visual_navigation/cfg/ros_params.py @@ -50,6 +50,7 @@ class RosLearningNodeParams: supervision_callback_rate: float # hertz learning_thread_rate: float # hertz logging_thread_rate: float # hertz + load_save_checkpoint_rate: float # hertz # Runtime options device: str @@ -97,3 +98,4 @@ class RosFeatureExtractorNodeParams: # Threads image_callback_rate: float # hertz + load_save_checkpoint_rate: float # hertz diff --git a/wild_visual_navigation/feature_extractor/segment_extractor.py b/wild_visual_navigation/feature_extractor/segment_extractor.py index 1aec152b..5726ed2a 100644 --- a/wild_visual_navigation/feature_extractor/segment_extractor.py +++ b/wild_visual_navigation/feature_extractor/segment_extractor.py @@ -41,7 +41,7 @@ def adjacency_list(self, seg: torch.tensor): Returns: adjacency_list (torch.Tensor, dtype=torch.long, shape=(N, 2): Adjacency list of undirected graph """ - assert seg.shape[0] == 1 and len(seg.shape) == 4 + assert seg.shape[0] == 1 and len(seg.shape) == 4, f"{seg.shape}" res = self.f1(seg.type(torch.float32)) boundary_mask = (res != 0)[0, :, 2:-2, 2:-2] diff --git a/wild_visual_navigation/utils/testing.py b/wild_visual_navigation/utils/testing.py index 9c1a2224..b83af273 100644 --- a/wild_visual_navigation/utils/testing.py +++ b/wild_visual_navigation/utils/testing.py @@ -16,8 +16,8 @@ def load_test_image(): def get_dino_transform(): transform = T.Compose( [ - T.Resize(448, T.InterpolationMode.NEAREST), - T.CenterCrop(448), + T.Resize(224, T.InterpolationMode.NEAREST), + T.CenterCrop(224), ] ) return transform diff --git a/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml b/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml index 66aa0093..f772eda5 100644 --- a/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml +++ b/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml @@ -21,8 +21,8 @@ image_graph_dist_thr: 0.2 # meters supervision_graph_dist_thr: 0.1 # meters network_input_image_height: 224 # 448 network_input_image_width: 224 # 448 -segmentation_type: "random" # Options: slic, grid, random, stego -feature_type: "dino" # Options: dino, dinov2, stego +segmentation_type: "stego" # Options: slic, grid, random, stego +feature_type: "stego" # Options: dino, dinov2, stego dino_patch_size: 8 # 8 or 16; 8 is finer dino_backbone: vit_small slic_num_components: 100 @@ -46,6 +46,7 @@
supervision_callback_rate: 10 # hertz learning_thread_rate: 10 # hertz logging_thread_rate: 2 # hertz status_thread_rate: 0.5 # hertz +load_save_checkpoint_rate: 0.2 # hertz, 1/0.2 = 5 sec equivalent # Runtime options device: "cuda" diff --git a/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py b/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py index 0ee77832..98439e71 100644 --- a/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py +++ b/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py @@ -38,6 +38,7 @@ def __init__(self, node_name): # Timers to control the rate of the subscriber self._last_image_ts = rospy.get_time() + self._last_checkpoint_ts = rospy.get_time() # Load model self._model = get_model(self._params.model).to(self._ros_params.device) @@ -280,7 +281,6 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo else: if self._ros_params.verbose: rospy.loginfo(f"[{self._node_name}] Image callback: {cam} -> Process") - self._scheduler.step() self._last_image_ts = ts @@ -292,7 +292,7 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo self._log_data[f"time_last_image_{cam}"] = rospy.get_time() # Update model from file if possible - self.load_model() + self.load_model(image_msg.header.stamp) # Convert image message to torch image torch_image = rc.ros_image_to_torch(image_msg, device=self._ros_params.device) @@ -414,12 +414,21 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo } raise Exception("Error in image callback") - def load_model(self): + # Step scheduler + self._scheduler.step() + + def load_model(self, stamp): """Method to load the new model weights to perform inference on the incoming images Args: None """ + ts = stamp.to_sec() + if abs(ts - self._last_checkpoint_ts) < 1.0 / self._ros_params.load_save_checkpoint_rate: + return + + self._last_checkpoint_ts = ts + try: # self._load_model_counter += 1 # if self._load_model_counter % 10 == 0: diff --git a/wild_visual_navigation_ros/scripts/wvn_learning_node.py b/wild_visual_navigation_ros/scripts/wvn_learning_node.py index 3fbae134..35d05fe1 100644 --- a/wild_visual_navigation_ros/scripts/wvn_learning_node.py +++ b/wild_visual_navigation_ros/scripts/wvn_learning_node.py @@ -45,8 +45,9 @@ class WvnLearning: def __init__(self, node_name): # Timers to control the rate of the publishers - self.last_image_ts = rospy.get_time() - self.last_supervision_ts = rospy.get_time() + self._last_image_ts = rospy.get_time() + self._last_supervision_ts = rospy.get_time() + self._last_checkpoint_ts = rospy.get_time() # Prepare variables self._node_name = node_name @@ -55,56 +56,56 @@ def __init__(self, node_name): self.read_params() # Visualization - self.color_palette = sns.color_palette(self.ros_params.colormap, as_cmap=True) + self._color_palette = sns.color_palette(self._ros_params.colormap, as_cmap=True) # Setup Mission Folder - model_path = create_experiment_folder(self.params) + model_path = create_experiment_folder(self._params) - with read_write(self.params): - self.params.general.model_path = model_path + with read_write(self._params): + self._params.general.model_path = model_path # Initialize traversability estimator - self.traversability_estimator = TraversabilityEstimator( - params=self.params, - device=self.ros_params.device, - image_size=self.ros_params.network_input_image_height, # Note: we assume height == width - segmentation_type=self.ros_params.segmentation_type, - 
feature_type=self.ros_params.feature_type, - max_distance=self.ros_params.traversability_radius, - image_distance_thr=self.ros_params.image_graph_dist_thr, - supervision_distance_thr=self.ros_params.supervision_graph_dist_thr, - min_samples_for_training=self.ros_params.min_samples_for_training, - vis_node_index=self.ros_params.vis_node_index, - mode=self.ros_params.mode, - extraction_store_folder=self.ros_params.extraction_store_folder, - scale_traversability=self.ros_params.scale_traversability, + self._traversability_estimator = TraversabilityEstimator( + params=self._params, + device=self._ros_params.device, + image_size=self._ros_params.network_input_image_height, # Note: we assume height == width + segmentation_type=self._ros_params.segmentation_type, + feature_type=self._ros_params.feature_type, + max_distance=self._ros_params.traversability_radius, + image_distance_thr=self._ros_params.image_graph_dist_thr, + supervision_distance_thr=self._ros_params.supervision_graph_dist_thr, + min_samples_for_training=self._ros_params.min_samples_for_training, + vis_node_index=self._ros_params.vis_node_index, + mode=self._ros_params.mode, + extraction_store_folder=self._ros_params.extraction_store_folder, + scale_traversability=self._ros_params.scale_traversability, anomaly_detection=self.anomaly_detection, ) # Initialize traversability generator to process velocity commands - self.supervision_generator = SupervisionGenerator( - self.ros_params.device, + self._supervision_generator = SupervisionGenerator( + self._ros_params.device, kf_process_cov=0.1, kf_meas_cov=10, kf_outlier_rejection="huber", kf_outlier_rejection_delta=0.5, sigmoid_slope=20, sigmoid_cutoff=0.25, # 0.2 - untraversable_thr=self.ros_params.untraversable_thr, # 0.1 + untraversable_thr=self._ros_params.untraversable_thr, # 0.1 time_horizon=0.05, ) # Initialize camera handler for subscription/publishing - self.system_events = {} + self._system_events = {} # Setup ros - self.setup_ros(setup_fully=self.ros_params.mode != WVNMode.EXTRACT_LABELS) + self.setup_ros(setup_fully=self._ros_params.mode != WVNMode.EXTRACT_LABELS) # Setup Timer if needed - self.timer = ClassTimer( + self._timer = ClassTimer( objects=[ self, - self.traversability_estimator, - self.traversability_estimator._visualizer, - self.supervision_generator, + self._traversability_estimator, + self._traversability_estimator._visualizer, + self._supervision_generator, ], names=[ "WVN", @@ -113,9 +114,9 @@ def __init__(self, node_name): "SupervisionGenerator", ], enabled=( - self.ros_params.print_image_callback_time - or self.ros_params.print_supervision_callback_time - or self.ros_params.log_time + self._ros_params.print_image_callback_time + or self._ros_params.print_supervision_callback_time + or self._ros_params.log_time ), ) @@ -127,8 +128,8 @@ def __init__(self, node_name): # Launch processes print("-" * 80) rospy.loginfo(f"[{self._node_name}] Launching [learning] thread") - if self.ros_params.mode != WVNMode.EXTRACT_LABELS: - self.learning_thread_stop_event = Event() + if self._ros_params.mode != WVNMode.EXTRACT_LABELS: + self._learning_thread_stop_event = Event() self.learning_thread = Thread(target=self.learning_thread_loop, name="learning") self.learning_thread.start() @@ -140,22 +141,22 @@ def __init__(self, node_name): def shutdown_callback(self, *args, **kwargs): # Write stuff to files rospy.logwarn("Shutdown callback called") - if self.ros_params.mode != WVNMode.EXTRACT_LABELS: - self.learning_thread_stop_event.set() + if self._ros_params.mode != 
WVNMode.EXTRACT_LABELS: + self._learning_thread_stop_event.set() # self.logging_thread_stop_event.set() print(f"[{self._node_name}] Storing learned checkpoint...", end="") - self.traversability_estimator.save_checkpoint(self.params.general.model_path, "last_checkpoint.pt") + self._traversability_estimator.save_checkpoint(self._params.general.model_path, "last_checkpoint.pt") print("done") - if self.ros_params.log_time: + if self._ros_params.log_time: print(f"[{self._node_name}] Storing timer data...", end="") - self.timer.store(folder=self.params.general.model_path) + self._timer.store(folder=self._params.general.model_path) print("done") print(f"[{self._node_name}] Joining learning thread...", end="") - if self.ros_params.mode != WVNMode.EXTRACT_LABELS: - self.learning_thread_stop_event.set() + if self._ros_params.mode != WVNMode.EXTRACT_LABELS: + self._learning_thread_stop_event.set() self.learning_thread.join() # self.logging_thread_stop_event.set() @@ -171,25 +172,25 @@ def learning_thread_loop(self): We can only set the rate using rosparam """ # Set rate - rate = rospy.Rate(self.ros_params.learning_thread_rate) + rate = rospy.Rate(self._ros_params.learning_thread_rate) # Learning loop while True: - self.system_events["learning_thread_loop"] = { + self._system_events["learning_thread_loop"] = { "time": rospy.get_time(), "value": "running", } - self.learning_thread_stop_event.wait(timeout=0.01) - if self.learning_thread_stop_event.is_set(): + self._learning_thread_stop_event.wait(timeout=0.01) + if self._learning_thread_stop_event.is_set(): rospy.logwarn("Stopped learning thread") break # Optimize model with ClassContextTimer(parent_obj=self, block_name="training_step_time"): - res = self.traversability_estimator.train() + res = self._traversability_estimator.train() - if self.step != self.traversability_estimator.step: - self.step_time = rospy.get_time() - self.step = self.traversability_estimator.step + if self._step != self._traversability_estimator.step: + self._step_time = rospy.get_time() + self._step = self._traversability_estimator.step # Publish loss system_state = SystemState() @@ -197,24 +198,24 @@ def learning_thread_loop(self): if hasattr(system_state, k): setattr(system_state, k, res[k]) - system_state.pause_learning = self.traversability_estimator.pause_learning - system_state.mode = self.ros_params.mode.value - system_state.step = self.step - self.pub_system_state.publish(system_state) + system_state.pause_learning = self._traversability_estimator.pause_learning + system_state.mode = self._ros_params.mode.value + system_state.step = self._step + self._pub_system_state.publish(system_state) # Get current weights - new_model_state_dict = self.traversability_estimator._model.state_dict() + new_model_state_dict = self._traversability_estimator._model.state_dict() # Compute ROC Threshold - if self.ros_params.scale_traversability: - if self.traversability_estimator._auxiliary_training_roc._update_count != 0: + if self._ros_params.scale_traversability: + if self._traversability_estimator._auxiliary_training_roc._update_count != 0: try: ( fpr, tpr, thresholds, - ) = self.traversability_estimator._auxiliary_training_roc.compute() - index = torch.where(fpr > self.ros_params.scale_traversability_max_fpr)[0][0] + ) = self._traversability_estimator._auxiliary_training_roc.compute() + index = torch.where(fpr > self._ros_params.scale_traversability_max_fpr)[0][0] traversability_threshold = thresholds[index] except Exception: traversability_threshold = 0.5 @@ -222,31 +223,35 @@ def 
learning_thread_loop(self): traversability_threshold = 0.5 new_model_state_dict["traversability_threshold"] = traversability_threshold - cg = self.traversability_estimator._traversability_loss._confidence_generator + cg = self._traversability_estimator._traversability_loss._confidence_generator new_model_state_dict["confidence_generator"] = cg.get_dict() - os.remove(f"{WVN_ROOT_DIR}/.tmp_state_dict.pt") - torch.save(new_model_state_dict, f"{WVN_ROOT_DIR}/.tmp_state_dict.pt") + # Check the rate + ts = rospy.get_time() + if abs(ts - self._last_checkpoint_ts) > 1.0 / self._ros_params.load_save_checkpoint_rate: + os.remove(f"{WVN_ROOT_DIR}/.tmp_state_dict.pt") + torch.save(new_model_state_dict, f"{WVN_ROOT_DIR}/.tmp_state_dict.pt") + self._last_checkpoint_ts = ts rate.sleep() - self.system_events["learning_thread_loop"] = { + self._system_events["learning_thread_loop"] = { "time": rospy.get_time(), "value": "finished", } - self.learning_thread_stop_event.clear() + self._learning_thread_stop_event.clear() def logging_thread_loop(self): - rate = rospy.Rate(self.ros_params.logging_thread_rate) + rate = rospy.Rate(self._ros_params.logging_thread_rate) # Learning loop while True: - self.learning_thread_stop_event.wait(timeout=0.01) - if self.learning_thread_stop_event.is_set(): + self._learning_thread_stop_event.wait(timeout=0.01) + if self._learning_thread_stop_event.is_set(): rospy.logwarn("Stopped logging thread") break current_time = rospy.get_time() - tmp = self.system_events.copy() + tmp = self._system_events.copy() rospy.loginfo("System Events:") for k, v in tmp.items(): value = v["value"] @@ -258,56 +263,56 @@ def logging_thread_loop(self): rospy.loginfo(msg) rate.sleep() rospy.loginfo("--------------") - self.learning_thread_stop_event.clear() + self._learning_thread_stop_event.clear() @accumulate_time def read_params(self): """Reads all the parameters from the parameter server""" - self.params = OmegaConf.structured(ExperimentParams) - self.ros_params = OmegaConf.structured(RosLearningNodeParams) + self._params = OmegaConf.structured(ExperimentParams) + self._ros_params = OmegaConf.structured(RosLearningNodeParams) # Override the empty dataclass with values from ros parmeter server - with read_write(self.ros_params): - for k in self.ros_params.keys(): - self.ros_params[k] = rospy.get_param(f"~{k}") + with read_write(self._ros_params): + for k in self._ros_params.keys(): + self._ros_params[k] = rospy.get_param(f"~{k}") - self.ros_params.robot_height = rospy.get_param("~robot_height") # TODO robot_height currently not used + self._ros_params.robot_height = rospy.get_param("~robot_height") # TODO robot_height currently not used - with read_write(self.ros_params): - self.ros_params.mode = WVNMode.from_string(self.ros_params.mode) + with read_write(self._ros_params): + self._ros_params.mode = WVNMode.from_string(self._ros_params.mode) - with read_write(self.params): - self.params.general.name = self.ros_params.mission_name - self.params.general.timestamp = self.ros_params.mission_timestamp - self.params.general.log_confidence = self.ros_params.log_confidence - self.params.loss.confidence_std_factor = self.ros_params.confidence_std_factor - self.params.loss.w_temp = 0 + with read_write(self._params): + self._params.general.name = self._ros_params.mission_name + self._params.general.timestamp = self._ros_params.mission_timestamp + self._params.general.log_confidence = self._ros_params.log_confidence + self._params.loss.confidence_std_factor = self._ros_params.confidence_std_factor + 
self._params.loss.w_temp = 0 # Parse operation modes - if self.ros_params.mode == WVNMode.ONLINE: + if self._ros_params.mode == WVNMode.ONLINE: rospy.logwarn( f"[{self._node_name}] WARNING: online_mode enabled. The graph will not store any debug/training data such as images\n" ) - elif self.ros_params.mode == WVNMode.EXTRACT_LABELS: - with read_write(self.ros_params): + elif self._ros_params.mode == WVNMode.EXTRACT_LABELS: + with read_write(self._ros_params): # TODO verify if this is needed - self.ros_params.image_callback_rate = 3 - self.ros_params.supervision_callback_rate = 4 - self.ros_params.image_graph_dist_thr = 0.2 - self.ros_params.supervision_graph_dist_thr = 0.1 + self._ros_params.image_callback_rate = 3 + self._ros_params.supervision_callback_rate = 4 + self._ros_params.image_graph_dist_thr = 0.2 + self._ros_params.supervision_graph_dist_thr = 0.1 os.makedirs( - os.path.join(self.ros_params.extraction_store_folder, "image"), + os.path.join(self._ros_params.extraction_store_folder, "image"), exist_ok=True, ) os.makedirs( - os.path.join(self.ros_params.extraction_store_folder, "supervision_mask"), + os.path.join(self._ros_params.extraction_store_folder, "supervision_mask"), exist_ok=True, ) - self.step = -1 - self.step_time = rospy.get_time() - self.anomaly_detection = self.params.model.name == "LinearRnvp" + self._step = -1 + self._step_time = rospy.get_time() + self.anomaly_detection = self._params.model.name == "LinearRnvp" def setup_ros(self, setup_fully=True): """Main function to setup ROS-related stuff: publishers, subscribers and services""" @@ -317,35 +322,35 @@ def setup_ros(self, setup_fully=True): self.tf_listener = tf2_ros.TransformListener(self.tf_buffer) # Robot state callback - robot_state_sub = message_filters.Subscriber(self.ros_params.robot_state_topic, RobotState) + robot_state_sub = message_filters.Subscriber(self._ros_params.robot_state_topic, RobotState) cache1 = message_filters.Cache(robot_state_sub, 10) # noqa: F841 - desired_twist_sub = message_filters.Subscriber(self.ros_params.desired_twist_topic, TwistStamped) + desired_twist_sub = message_filters.Subscriber(self._ros_params.desired_twist_topic, TwistStamped) cache2 = message_filters.Cache(desired_twist_sub, 10) # noqa: F841 - self.robot_state_sub = message_filters.ApproximateTimeSynchronizer( + self._robot_state_sub = message_filters.ApproximateTimeSynchronizer( [robot_state_sub, desired_twist_sub], queue_size=10, slop=0.5 ) rospy.loginfo( - f"[{self._node_name}] Start waiting for RobotState topic {self.ros_params.robot_state_topic} being published!" + f"[{self._node_name}] Start waiting for RobotState topic {self._ros_params.robot_state_topic} being published!" ) - rospy.wait_for_message(self.ros_params.robot_state_topic, RobotState) + rospy.wait_for_message(self._ros_params.robot_state_topic, RobotState) rospy.loginfo( - f"[{self._node_name}] Start waiting for TwistStamped topic {self.ros_params.desired_twist_topic} being published!" + f"[{self._node_name}] Start waiting for TwistStamped topic {self._ros_params.desired_twist_topic} being published!" 
) - rospy.wait_for_message(self.ros_params.desired_twist_topic, TwistStamped) - self.robot_state_sub.registerCallback(self.robot_state_callback) + rospy.wait_for_message(self._ros_params.desired_twist_topic, TwistStamped) + self._robot_state_sub.registerCallback(self.robot_state_callback) - self.camera_handler = {} + self._camera_handler = {} # Image callback - for cam in self.ros_params.camera_topics: + for cam in self._ros_params.camera_topics: # Initialize camera handler for given cam - self.camera_handler[cam] = {} + self._camera_handler[cam] = {} # Store camera name - self.ros_params.camera_topics[cam]["name"] = cam + self._ros_params.camera_topics[cam]["name"] = cam # Set subscribers - if self.ros_params.mode == WVNMode.DEBUG: + if self._ros_params.mode == WVNMode.DEBUG: # In debug mode additionally send the image to the callback function self._visualizer = LearningVisualizer() @@ -357,7 +362,7 @@ def setup_ros(self, setup_fully=True): sync = message_filters.ApproximateTimeSynchronizer( [imagefeat_sub, info_sub, image_sub], queue_size=4, slop=0.5 ) - sync.registerCallback(self.imagefeat_callback, self.ros_params.camera_topics[cam]) + sync.registerCallback(self.imagefeat_callback, self._ros_params.camera_topics[cam]) last_image_overlay_pub = rospy.Publisher( f"/wild_visual_navigation_node/{cam}/debug/last_node_image_overlay", @@ -365,8 +370,8 @@ def setup_ros(self, setup_fully=True): queue_size=10, ) - self.camera_handler[cam]["debug"] = {} - self.camera_handler[cam]["debug"]["image_overlay"] = last_image_overlay_pub + self._camera_handler[cam]["debug"] = {} + self._camera_handler[cam]["debug"]["image_overlay"] = last_image_overlay_pub else: imagefeat_sub = message_filters.Subscriber( @@ -376,36 +381,38 @@ def setup_ros(self, setup_fully=True): sync = message_filters.ApproximateTimeSynchronizer( [imagefeat_sub, info_sub], queue_size=4, slop=0.5 ) - sync.registerCallback(self.imagefeat_callback, self.ros_params.camera_topics[cam]) + sync.registerCallback(self.imagefeat_callback, self._ros_params.camera_topics[cam]) # 3D outputs - self.pub_debug_supervision_graph = rospy.Publisher( + self._pub_debug_supervision_graph = rospy.Publisher( "/wild_visual_navigation_node/supervision_graph", Path, queue_size=10 ) - self.pub_mission_graph = rospy.Publisher("/wild_visual_navigation_node/mission_graph", Path, queue_size=10) - self.pub_graph_footprints = rospy.Publisher( + self._pub_mission_graph = rospy.Publisher("/wild_visual_navigation_node/mission_graph", Path, queue_size=10) + self._pub_graph_footprints = rospy.Publisher( "/wild_visual_navigation_node/graph_footprints", Marker, queue_size=10 ) # 1D outputs - self.pub_instant_traversability = rospy.Publisher( + self._pub_instant_traversability = rospy.Publisher( "/wild_visual_navigation_node/instant_traversability", Float32, queue_size=10, ) - self.pub_system_state = rospy.Publisher("/wild_visual_navigation_node/system_state", SystemState, queue_size=10) + self._pub_system_state = rospy.Publisher( + "/wild_visual_navigation_node/system_state", SystemState, queue_size=10 + ) # Services # Like, reset graph or the like - self.save_checkpt_service = rospy.Service("~save_checkpoint", SaveCheckpoint, self.save_checkpoint_callback) - self.load_checkpt_service = rospy.Service("~load_checkpoint", LoadCheckpoint, self.load_checkpoint_callback) + self._save_checkpt_service = rospy.Service("~save_checkpoint", SaveCheckpoint, self.save_checkpoint_callback) + self._load_checkpt_service = rospy.Service("~load_checkpoint", LoadCheckpoint, 
self.load_checkpoint_callback) - self.pause_learning_service = rospy.Service("~pause_learning", SetBool, self.pause_learning_callback) - self.reset_service = rospy.Service("~reset", Trigger, self.reset_callback) + self._pause_learning_service = rospy.Service("~pause_learning", SetBool, self.pause_learning_callback) + self._reset_service = rospy.Service("~reset", Trigger, self.reset_callback) def pause_learning_callback(self, req): """Start and stop the network training""" - prev_state = self.traversability_estimator.pause_learning - self.traversability_estimator.pause_learning = req.data + prev_state = self._traversability_estimator.pause_learning + self._traversability_estimator.pause_learning = req.data if not req.data and prev_state: message = "Resume training!" elif req.data and prev_state: @@ -414,7 +421,7 @@ def pause_learning_callback(self, req): message = "Training was already running!" elif req.data and not prev_state: message = "Pause training!" - message += f" Updated the network for {self.traversability_estimator.step} steps" + message += f" Updated the network for {self._traversability_estimator.step} steps" return True, message @@ -423,19 +430,19 @@ def reset_callback(self, req): rospy.logwarn(f"[{self._node_name}] System reset!") print(f"[{self._node_name}] Storing learned checkpoint...", end="") - self.traversability_estimator.save_checkpoint(self.params.general.model_path, "last_checkpoint.pt") + self._traversability_estimator.save_checkpoint(self._params.general.model_path, "last_checkpoint.pt") print("done") - if self.ros_params.log_time: + if self._ros_params.log_time: print(f"[{self._node_name}] Storing timer data...", end="") - self.timer.store(folder=self.params.general.model_path) + self._timer.store(folder=self._params.general.model_path) print("done") # Create new mission folder - create_experiment_folder(self.params) + create_experiment_folder(self._params) # Reset traversability estimator - self.traversability_estimator.reset() + self._traversability_estimator.reset() print(f"[{self._node_name}] Reset done") return TriggerResponse(True, "Reset done!") @@ -451,12 +458,12 @@ def save_checkpoint_callback(self, req): req.checkpoint_name = "last_checkpoint.pt" if req.mission_path == "": - message = f"[WARNING] Store checkpoint {req.checkpoint_name} default mission path: {self.params.general.model_path}/{req.checkpoint_name}" - req.mission_path = self.params.general.model_path + message = f"[WARNING] Store checkpoint {req.checkpoint_name} default mission path: {self._params.general.model_path}/{req.checkpoint_name}" + req.mission_path = self._params.general.model_path else: message = f"Store checkpoint {req.checkpoint_name} to: {req.mission_path}/{req.checkpoint_name}" - self.traversability_estimator.save_checkpoint(req.mission_path, req.checkpoint_name) + self._traversability_estimator.save_checkpoint(req.mission_path, req.checkpoint_name) return SaveCheckpointResponse(success=True, message=message) def load_checkpoint_callback(self, req): @@ -471,7 +478,7 @@ def load_checkpoint_callback(self, req): message=f"Path [{req.checkpoint_path}] is empty. 
Please check and try again", ) checkpoint_path = req.checkpoint_path - self.traversability_estimator.load_checkpoint(checkpoint_path) + self._traversability_estimator.load_checkpoint(checkpoint_path) return LoadCheckpointResponse(success=True, message=f"Checkpoint [{checkpoint_path}] loaded successfully") @accumulate_time @@ -504,7 +511,7 @@ def query_tf(self, parent_frame: str, child_frame: str, stamp: Optional[rospy.Ti rot /= np.linalg.norm(rot) return (trans, tuple(rot)) except Exception: - if self.ros_params.verbose: + if self._ros_params.verbose: # print("Error in query tf: ", e) rospy.logwarn(f"[{self._node_name}] Couldn't get between {parent_frame} and {child_frame}") return (None, None) @@ -517,31 +524,31 @@ def robot_state_callback(self, state_msg, desired_twist_msg: TwistStamped): state_msg (wild_visual_navigation_msgs/RobotState): Robot state message desired_twist_msg (geometry_msgs/TwistStamped): Desired twist message """ - self.system_events["robot_state_callback_received"] = { + self._system_events["robot_state_callback_received"] = { "time": rospy.get_time(), "value": "message received", } try: ts = state_msg.header.stamp.to_sec() - if abs(ts - self.last_supervision_ts) < 1.0 / self.ros_params.supervision_callback_rate: - self.system_events["robot_state_callback_canceled"] = { + if abs(ts - self._last_supervision_ts) < 1.0 / self._ros_params.supervision_callback_rate: + self._system_events["robot_state_callback_canceled"] = { "time": rospy.get_time(), "value": "canceled due to rate", } return - self.last_propio_ts = ts + self._last_supervision_ts = ts # Query transforms from TF success, pose_base_in_world = rc.ros_tf_to_torch( self.query_tf( - self.ros_params.fixed_frame, - self.ros_params.base_frame, + self._ros_params.fixed_frame, + self._ros_params.base_frame, state_msg.header.stamp, ), - device=self.ros_params.device, + device=self._ros_params.device, ) if not success: - self.system_events["robot_state_callback_canceled"] = { + self._system_events["robot_state_callback_canceled"] = { "time": rospy.get_time(), "value": "canceled due to pose_base_in_world", } @@ -549,35 +556,35 @@ def robot_state_callback(self, state_msg, desired_twist_msg: TwistStamped): success, pose_footprint_in_base = rc.ros_tf_to_torch( self.query_tf( - self.ros_params.base_frame, - self.ros_params.footprint_frame, + self._ros_params.base_frame, + self._ros_params.footprint_frame, state_msg.header.stamp, ), - device=self.ros_params.device, + device=self._ros_params.device, ) if not success: - self.system_events["robot_state_callback_canceled"] = { + self._system_events["robot_state_callback_canceled"] = { "time": rospy.get_time(), "value": "canceled due to pose_footprint_in_base", } return # The footprint requires a correction: we use the same orientation as the base - pose_footprint_in_base[:3, :3] = torch.eye(3, device=self.ros_params.device) + pose_footprint_in_base[:3, :3] = torch.eye(3, device=self._ros_params.device) # Convert state to tensor supervision_tensor, supervision_labels = rc.wvn_robot_state_to_torch( - state_msg, device=self.ros_params.device + state_msg, device=self._ros_params.device ) - current_twist_tensor = rc.twist_stamped_to_torch(state_msg.twist, device=self.ros_params.device) - desired_twist_tensor = rc.twist_stamped_to_torch(desired_twist_msg, device=self.ros_params.device) + current_twist_tensor = rc.twist_stamped_to_torch(state_msg.twist, device=self._ros_params.device) + desired_twist_tensor = rc.twist_stamped_to_torch(desired_twist_msg, device=self._ros_params.device) # 
Update traversability ( traversability, traversability_var, is_untraversable, - ) = self.supervision_generator.update_velocity_tracking( + ) = self._supervision_generator.update_velocity_tracking( current_twist_tensor, desired_twist_tensor, velocities=["vx", "vy"] ) @@ -588,9 +595,9 @@ def robot_state_callback(self, state_msg, desired_twist_msg: TwistStamped): pose_footprint_in_base=pose_footprint_in_base, twist_in_base=current_twist_tensor, desired_twist_in_base=desired_twist_tensor, - width=self.ros_params.robot_width, - length=self.ros_params.robot_length, - height=self.ros_params.robot_height, + width=self._ros_params.robot_width, + length=self._ros_params.robot_length, + height=self._ros_params.robot_height, supervision=supervision_tensor, traversability=traversability, traversability_var=traversability_var, @@ -598,15 +605,15 @@ def robot_state_callback(self, state_msg, desired_twist_msg: TwistStamped): ) # Add node to the graph - self.traversability_estimator.add_supervision_node(supervision_node) + self._traversability_estimator.add_supervision_node(supervision_node) - if self.ros_params.mode == WVNMode.DEBUG or self.ros_params.mode == WVNMode.ONLINE: + if self._ros_params.mode == WVNMode.DEBUG or self._ros_params.mode == WVNMode.ONLINE: self.visualize_supervision() - if self.ros_params.print_supervision_callback_time: - print(f"[{self._node_name}]\n{self.timer}") + if self._ros_params.print_supervision_callback_time: + print(f"[{self._node_name}]\n{self._timer}") - self.system_events["robot_state_callback_state"] = { + self._system_events["robot_state_callback_state"] = { "time": rospy.get_time(), "value": "executed successfully", } @@ -614,7 +621,7 @@ def robot_state_callback(self, state_msg, desired_twist_msg: TwistStamped): except Exception as e: traceback.print_exc() rospy.logerr(f"[{self._node_name}] error state callback", e) - self.system_events["robot_state_callback_state"] = { + self._system_events["robot_state_callback_state"] = { "time": rospy.get_time(), "value": f"failed to execute {e}", } @@ -629,44 +636,44 @@ def imagefeat_callback(self, *args): imagefeat_msg (wild_visual_navigation_msg/ImageFeatures): Incoming imagefeatures info_msg (sensor_msgs/CameraInfo): Camera info message associated to the image """ - if self.ros_params.mode == WVNMode.DEBUG: + if self._ros_params.mode == WVNMode.DEBUG: assert len(args) == 4 imagefeat_msg, info_msg, image_msg, camera_options = tuple(args) else: assert len(args) == 3 imagefeat_msg, info_msg, camera_options = tuple(args) - self.system_events["image_callback_received"] = { + self._system_events["image_callback_received"] = { "time": rospy.get_time(), "value": "message received", } - if self.ros_params.verbose: + if self._ros_params.verbose: print(f"[{self._node_name}] Image callback: {camera_options['name']}... 
", end="") try: # Run the callback so as to match the desired rate ts = imagefeat_msg.header.stamp.to_sec() - if abs(ts - self.last_image_ts) < 1.0 / self.ros_params.image_callback_rate: - if self.ros_params.verbose: + if abs(ts - self._last_image_ts) < 1.0 / self._ros_params.image_callback_rate: + if self._ros_params.verbose: print(f"skip") return else: - if self.ros_params.verbose: + if self._ros_params.verbose: print(f"process") - self.last_image_ts = ts + self._last_image_ts = ts # Query transforms from TF success, pose_base_in_world = rc.ros_tf_to_torch( self.query_tf( - self.ros_params.fixed_frame, - self.ros_params.base_frame, + self._ros_params.fixed_frame, + self._ros_params.base_frame, imagefeat_msg.header.stamp, ), - device=self.ros_params.device, + device=self._ros_params.device, ) if not success: - self.system_events["image_callback_canceled"] = { + self._system_events["image_callback_canceled"] = { "time": rospy.get_time(), "value": "canceled due to pose_base_in_world", } @@ -674,49 +681,49 @@ def imagefeat_callback(self, *args): success, pose_cam_in_base = rc.ros_tf_to_torch( self.query_tf( - self.ros_params.base_frame, + self._ros_params.base_frame, imagefeat_msg.header.frame_id, imagefeat_msg.header.stamp, ), - device=self.ros_params.device, + device=self._ros_params.device, ) if not success: - self.system_events["image_callback_canceled"] = { + self._system_events["image_callback_canceled"] = { "time": rospy.get_time(), "value": "canceled due to pose_cam_in_base", } return # Prepare image projector - K, H, W = rc.ros_cam_info_to_tensors(info_msg, device=self.ros_params.device) + K, H, W = rc.ros_cam_info_to_tensors(info_msg, device=self._ros_params.device) image_projector = ImageProjector( K=K, h=H, w=W, - new_h=self.ros_params.network_input_image_height, - new_w=self.ros_params.network_input_image_width, + new_h=self._ros_params.network_input_image_height, + new_w=self._ros_params.network_input_image_width, ) # Add image to base node # convert image message to torch image feature_segments = rc.ros_image_to_torch( imagefeat_msg.feature_segments, desired_encoding="passthrough", - device=self.ros_params.device, + device=self._ros_params.device, ).clone() h_small, w_small = feature_segments.shape[1:3] torch_image = torch.zeros( (3, h_small, w_small), - device=self.ros_params.device, + device=self._ros_params.device, dtype=torch.float32, ) # convert image message to torch image - if self.ros_params.mode == WVNMode.DEBUG: + if self._ros_params.mode == WVNMode.DEBUG: torch_image = rc.ros_image_to_torch( image_msg, desired_encoding="passthrough", - device=self.ros_params.device, + device=self._ros_params.device, ).clone() # Create mission node for the graph @@ -733,26 +740,26 @@ def imagefeat_callback(self, *args): dims = tuple(map(lambda x: x.size, ma.layout.dim)) mission_node.features = torch.from_numpy( np.array(ma.data, dtype=float).reshape(dims).astype(np.float32) - ).to(self.ros_params.device) + ).to(self._ros_params.device) mission_node.feature_segments = feature_segments[0] # Add node to graph - added_new_node = self.traversability_estimator.add_mission_node(mission_node, update_features=False) + added_new_node = self._traversability_estimator.add_mission_node(mission_node, update_features=False) - if self.ros_params.mode == WVNMode.DEBUG: + if self._ros_params.mode == WVNMode.DEBUG: # Publish current predictions self.visualize_mission_graph() # Publish supervision data depending on the mode self.visualize_image_overlay() if added_new_node: - 
self.traversability_estimator.update_visualization_node() + self._traversability_estimator.update_visualization_node() # Print callback time if required - if self.ros_params.print_image_callback_time: - rospy.loginfo(f"[{self._node_name}]\n{self.timer}") + if self._ros_params.print_image_callback_time: + rospy.loginfo(f"[{self._node_name}]\n{self._timer}") - self.system_events["image_callback_state"] = { + self._system_events["image_callback_state"] = { "time": rospy.get_time(), "value": "executed successfully", } @@ -760,7 +767,7 @@ def imagefeat_callback(self, *args): except Exception as e: traceback.print_exc() rospy.logerr(f"[{self._node_name}] error image callback", e) - self.system_events["image_callback_state"] = { + self._system_events["image_callback_state"] = { "time": rospy.get_time(), "value": f"failed to execute {e}", } @@ -775,14 +782,14 @@ def visualize_supervision(self): now = rospy.Time.now() supervision_graph_msg = Path() - supervision_graph_msg.header.frame_id = self.ros_params.fixed_frame + supervision_graph_msg.header.frame_id = self._ros_params.fixed_frame supervision_graph_msg.header.stamp = now # Footprints footprints_marker = Marker() footprints_marker.id = 0 footprints_marker.ns = "footprints" - footprints_marker.header.frame_id = self.ros_params.fixed_frame + footprints_marker.header.frame_id = self._ros_params.fixed_frame footprints_marker.header.stamp = now footprints_marker.type = Marker.TRIANGLE_LIST footprints_marker.action = Marker.ADD @@ -796,16 +803,16 @@ def visualize_supervision(self): footprints_marker.pose.position.z = 0.0 last_points = [None, None] - for node in self.traversability_estimator.get_supervision_nodes(): + for node in self._traversability_estimator.get_supervision_nodes(): # Path pose = PoseStamped() pose.header.stamp = now - pose.header.frame_id = self.ros_params.fixed_frame + pose.header.frame_id = self._ros_params.fixed_frame pose.pose = rc.torch_to_ros_pose(node.pose_base_in_world) supervision_graph_msg.poses.append(pose) # Color for traversability - r, g, b, _ = self.color_palette(node.traversability.item()) + r, g, b, _ = self._color_palette(node.traversability.item()) c = ColorRGBA(r, g, b, 0.95) # Rainbow path @@ -869,15 +876,15 @@ def visualize_supervision(self): # Publish if len(footprints_marker.points) % 3 != 0: - if self.ros_params.verbose: + if self._ros_params.verbose: rospy.loginfo(f"[{self._node_name}] number of points for footprint is {len(footprints_marker.points)}") return - self.pub_graph_footprints.publish(footprints_marker) - self.pub_debug_supervision_graph.publish(supervision_graph_msg) + self._pub_graph_footprints.publish(footprints_marker) + self._pub_debug_supervision_graph.publish(supervision_graph_msg) # Publish latest traversability - self.pub_instant_traversability.publish(self.supervision_generator.traversability) - self.system_events["visualize_supervision"] = { + self._pub_instant_traversability.publish(self._supervision_generator.traversability) + self._system_events["visualize_supervision"] = { "time": rospy.get_time(), "value": f"executed successfully", } @@ -890,24 +897,24 @@ def visualize_mission_graph(self): # Publish mission graph mission_graph_msg = Path() - mission_graph_msg.header.frame_id = self.ros_params.fixed_frame + mission_graph_msg.header.frame_id = self._ros_params.fixed_frame mission_graph_msg.header.stamp = now - for node in self.traversability_estimator.get_mission_nodes(): + for node in self._traversability_estimator.get_mission_nodes(): pose = PoseStamped() pose.header.stamp = now 
- pose.header.frame_id = self.ros_params.fixed_frame + pose.header.frame_id = self._ros_params.fixed_frame pose.pose = rc.torch_to_ros_pose(node.pose_cam_in_world) mission_graph_msg.poses.append(pose) - self.pub_mission_graph.publish(mission_graph_msg) + self._pub_mission_graph.publish(mission_graph_msg) @accumulate_time def visualize_image_overlay(self): """Publishes all the debugging, slow visualizations""" # Get visualization node - vis_node = self.traversability_estimator.get_mission_node_for_visualization() + vis_node = self._traversability_estimator.get_mission_node_for_visualization() # Publish reprojections of last node in graph if vis_node is not None: @@ -918,7 +925,7 @@ def visualize_image_overlay(self): torch_mask = torch_mask.float() image_out = self._visualizer.plot_detectron_classification(torch_image, torch_mask, cmap="Greens") - self.camera_handler[cam]["debug"]["image_overlay"].publish(rc.numpy_to_ros_image(image_out)) + self._camera_handler[cam]["debug"]["image_overlay"].publish(rc.numpy_to_ros_image(image_out)) if __name__ == "__main__":
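The checkpoint hand-off introduced above works as a simple timestamp-based throttle: wvn_learning_node.py writes .tmp_state_dict.pt at most once per 1 / load_save_checkpoint_rate seconds, and wvn_feature_extractor_node.py reloads it under the same limit, keyed on the incoming image stamp. The sketch below is a minimal stand-alone illustration of that pattern, assuming rates in hertz and timestamps in seconds; RateLimiter and ready() are illustrative names only and do not exist in the repository.

    class RateLimiter:
        """Accept at most one event per 1/rate_hz seconds (hypothetical helper)."""

        def __init__(self, rate_hz: float, start_ts: float):
            self._min_period = 1.0 / rate_hz  # e.g. 0.2 Hz -> 5.0 s
            self._last_ts = start_ts

        def ready(self, ts: float) -> bool:
            # Same idea as the guards added in both nodes, which compare
            # abs(ts - self._last_checkpoint_ts) against 1.0 / load_save_checkpoint_rate:
            # reject while the window has not elapsed, otherwise re-arm and accept.
            if abs(ts - self._last_ts) < self._min_period:
                return False
            self._last_ts = ts
            return True

With the default load_save_checkpoint_rate: 0.2 from default.yaml this accepts one event every 5 seconds, so the temporary state dict is saved by the learning thread and re-loaded in the image callback at most every 5 s rather than on every training iteration or processed image.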