From 14a486ac1e2fac921ca316f68129995c743c0250 Mon Sep 17 00:00:00 2001 From: Jonas Frey Date: Wed, 1 Nov 2023 17:39:28 +0100 Subject: [PATCH] feature extractor node working --- cfg/env/jetson.yaml | 1 - tests/test_configuration.py | 16 ++ wild_visual_navigation/cfg/__init__.py | 1 + .../cfg/experiment_params.py | 3 +- wild_visual_navigation/cfg/ros_params.py | 145 +++++++++++++ wild_visual_navigation/utils/__init__.py | 6 +- wild_visual_navigation/utils/loss.py | 3 +- wild_visual_navigation_ros/CMakeLists.txt | 5 +- .../wild_visual_navigation/default.yaml | 2 + .../scripts/rosbag_play.sh | 52 +++-- .../scripts/wvn_feature_extractor_node.py | 191 ++++++++---------- 11 files changed, 282 insertions(+), 143 deletions(-) delete mode 100644 cfg/env/jetson.yaml create mode 100644 tests/test_configuration.py create mode 100644 wild_visual_navigation/cfg/ros_params.py diff --git a/cfg/env/jetson.yaml b/cfg/env/jetson.yaml deleted file mode 100644 index 8e38b327..00000000 --- a/cfg/env/jetson.yaml +++ /dev/null @@ -1 +0,0 @@ -base: results \ No newline at end of file diff --git a/tests/test_configuration.py b/tests/test_configuration.py new file mode 100644 index 00000000..08d15b92 --- /dev/null +++ b/tests/test_configuration.py @@ -0,0 +1,16 @@ +from wild_visual_navigation.cfg import RosLearningNodeParams +from omegaconf import OmegaConf +from omegaconf import read_write + + +def test_configuration(): + cfg = OmegaConf.structured(RosLearningNodeParams) + print(cfg) + with read_write(cfg): + cfg.image_callback_rate = 1.0 + + print(cfg.image_callback_rate) + + +if __name__ == "__main__": + test_configuration() diff --git a/wild_visual_navigation/cfg/__init__.py b/wild_visual_navigation/cfg/__init__.py index 829340ea..a1857c40 100644 --- a/wild_visual_navigation/cfg/__init__.py +++ b/wild_visual_navigation/cfg/__init__.py @@ -1 +1,2 @@ from .experiment_params import ExperimentParams +from .ros_params import RosLearningNodeParams, RosFeatureExtractorNodeParams diff --git a/wild_visual_navigation/cfg/experiment_params.py b/wild_visual_navigation/cfg/experiment_params.py index 5535b25d..5ef26107 100644 --- a/wild_visual_navigation/cfg/experiment_params.py +++ b/wild_visual_navigation/cfg/experiment_params.py @@ -1,10 +1,9 @@ from dataclasses import dataclass, field from typing import Tuple, Dict, List, Optional -from simple_parsing.helpers import Serializable @dataclass -class ExperimentParams(Serializable): +class ExperimentParams: @dataclass class GeneralParams: name: str = "debug/debug" diff --git a/wild_visual_navigation/cfg/ros_params.py b/wild_visual_navigation/cfg/ros_params.py new file mode 100644 index 00000000..c3e8068b --- /dev/null +++ b/wild_visual_navigation/cfg/ros_params.py @@ -0,0 +1,145 @@ +from dataclasses import dataclass +from typing import Dict +from typing import Any + + +@dataclass +class RosLearningNodeParams: + # TODO remove all unnecessary topics here + # Input topics + robot_state_topic: str + desired_twist_topic: str + # desired_twist_topic: "/log/state/desiredRobotTwist" + + # Relevant frames + fixed_frame: str + base_frame: str + footprint_frame: str + + # Robot size + robot_length: float + robot_width: float + robot_height: float + + # Robot specs + robot_max_velocity: float + + # Traversability estimation params + traversability_radius: float # meters + image_graph_dist_thr: float # meters + proprio_graph_dist_thr: float # meters + network_input_image_height: int # 448 + network_input_image_width: int # 448 + segmentation_type: str + feature_type: str + dino_patch_size: int # 8 or 16; 8 is finer + slic_num_components: int + dino_dim: int # 90 or 384; 384 is better + confidence_std_factor: float + scale_traversability: bool # This parameter needs to be false when using the anomaly detection model + scale_traversability_max_fpr: float + min_samples_for_training: int + prediction_per_pixel: bool + traversability_threshold: float + clip_to_binary: bool + + # Supervision Generator + untraversable_thr: float + + mission_name: str + mission_timestamp: bool + + # Threads + image_callback_rate: float # hertz + proprio_callback_rate: float # hertz + learning_thread_rate: float # hertz + logging_thread_rate: float # hertz + status_thread_rate: float # hertz + + # Runtime options + device: str + mode: str # check out comments in the class WVNMode + colormap: str + + print_image_callback_time: bool + print_proprio_callback_time: bool + log_time: bool + log_confidence: bool + verbose: bool + debug_supervision_node_index_from_last: int + use_debug_for_desired: bool + + extraction_store_folder: str + exp: str + use_binary_only: bool + + +@dataclass +class RosFeatureExtractorNodeParams: + # Input topics + robot_state_topic: str + desired_twist_topic: str + # desired_twist_topic: "/log/state/desiredRobotTwist" + + # Relevant frames + fixed_frame: str + base_frame: str + footprint_frame: str + + # Robot size + robot_length: float + robot_width: float + robot_height: float + + # Robot specs + robot_max_velocity: float + + # Traversability estimation params + traversability_radius: float # meters + image_graph_dist_thr: float # meters + proprio_graph_dist_thr: float # meters + network_input_image_height: int # 448 + network_input_image_width: int # 448 + segmentation_type: str + feature_type: str + dino_patch_size: int # 8 or 16; 8 is finer + slic_num_components: int + dino_dim: int # 90 or 384; 384 is better + confidence_std_factor: float + scale_traversability: bool # This parameter needs to be false when using the anomaly detection model + scale_traversability_max_fpr: float + min_samples_for_training: int + prediction_per_pixel: bool + traversability_threshold: float + clip_to_binary: bool + + # Supervision Generator + untraversable_thr: float + + mission_name: str + mission_timestamp: bool + + # Threads + image_callback_rate: float # hertz + proprio_callback_rate: float # hertz + learning_thread_rate: float # hertz + logging_thread_rate: float # hertz + status_thread_rate: float # hertz + + # Runtime options + device: str + mode: str # check out comments in the class WVNMode + colormap: str + + print_image_callback_time: bool + print_proprio_callback_time: bool + log_time: bool + log_confidence: bool + verbose: bool + debug_supervision_node_index_from_last: int + use_debug_for_desired: bool + + extraction_store_folder: str + exp: str + use_binary_only: bool + camera_topics: Dict[str, Any] diff --git a/wild_visual_navigation/utils/__init__.py b/wild_visual_navigation/utils/__init__.py index a6ae229c..b502726d 100644 --- a/wild_visual_navigation/utils/__init__.py +++ b/wild_visual_navigation/utils/__init__.py @@ -3,13 +3,13 @@ from .loading import load_env, load_yaml, file_path from .create_experiment_folder import create_experiment_folder from .get_confidence import get_confidence -from .loss import TraversabilityLoss, AnomalyLoss -from .metric_logger import MetricLogger from .kalman_filter import KalmanFilter +from .confidence_generator import ConfidenceGenerator +from .metric_logger import MetricLogger from .meshes import make_box, make_rounded_box, make_ellipsoid, make_plane, make_polygon_from_points, make_dense_plane from .klt_tracker import KLTTracker, KLTTrackerOpenCV -from .confidence_generator import ConfidenceGenerator from .operation_modes import WVNMode from .dataset_info import perugia_dataset, ROOT_DIR from .override_params import override_params from .gpu_monitor import GpuMonitor, SystemLevelGpuMonitor, SystemLevelContextGpuMonitor, accumulate_memory +from .loss import TraversabilityLoss, AnomalyLoss diff --git a/wild_visual_navigation/utils/loss.py b/wild_visual_navigation/utils/loss.py index f4aa6837..65c36c26 100644 --- a/wild_visual_navigation/utils/loss.py +++ b/wild_visual_navigation/utils/loss.py @@ -1,8 +1,9 @@ +from wild_visual_navigation.utils import ConfidenceGenerator + import torch.nn.functional as F from torch_geometric.data import Data import torch from typing import Optional -from wild_visual_navigation.utils import ConfidenceGenerator from torch import nn diff --git a/wild_visual_navigation_ros/CMakeLists.txt b/wild_visual_navigation_ros/CMakeLists.txt index 0435751a..908adcda 100644 --- a/wild_visual_navigation_ros/CMakeLists.txt +++ b/wild_visual_navigation_ros/CMakeLists.txt @@ -14,12 +14,9 @@ catkin_package( ) catkin_python_setup() -catkin_install_python(PROGRAMS scripts/wild_visual_navigation_node.py - scripts/wvn_feature_extractor_node.py +catkin_install_python(PROGRAMS scripts/wvn_feature_extractor_node.py scripts/wvn_learning_node.py scripts/overlay_images.py - scripts/shift_gridmap.py scripts/smart_carrot.py scripts/rosbag_play.sh - scripts/rotate_image.py DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}) diff --git a/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml b/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml index 71f65689..00f1f9b8 100644 --- a/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml +++ b/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml @@ -1,3 +1,5 @@ +# TODO split into three configuration files - shared - learning - feature_extractor + # Input topics robot_state_topic: "/wild_visual_navigation_node/robot_state" desired_twist_topic: "/motion_reference/command_twist" diff --git a/wild_visual_navigation_ros/scripts/rosbag_play.sh b/wild_visual_navigation_ros/scripts/rosbag_play.sh index 902b2470..917a2852 100755 --- a/wild_visual_navigation_ros/scripts/rosbag_play.sh +++ b/wild_visual_navigation_ros/scripts/rosbag_play.sh @@ -6,35 +6,33 @@ for option in "$@"; do args="$args /elevation_mapping/elevation_map_raw:=/recorded/elevation_mapping/elevation_map_raw \ /elevation_mapping/semantic_map_raw:=/recorded/elevation_mapping/semantic_map_raw" elif [ "$option" == "--wvn" ]; then - args="$args /wild_visual_navigation_node/front/camera_info:=/recorded_wvn/wild_visual_navigation_node/front/camera_info \ - /wild_visual_navigation_node/front/confidence:=/recorded_wvn/wild_visual_navigation_node/front/confidence \ - /wild_visual_navigation_node/front/image_input:=/recorded_wvn/wild_visual_navigation_node/front/image_input \ - /wild_visual_navigation_node/front/traversability:=/recorded_wvn/wild_visual_navigation_node/front/traversability \ - /wild_visual_navigation_node/graph_footprints:=/recorded_wvn/wild_visual_navigation_node/graph_footprints \ - /wild_visual_navigation_node/instant_traversability:=/recorded_wvn/wild_visual_navigation_node/instant_traversability \ - /wild_visual_navigation_node/proprioceptive_graph:=/recorded_wvn/wild_visual_navigation_node/proprioceptive_graph \ - /wild_visual_navigation_node/robot_state:=/recorded_wvn/wild_visual_navigation_node/robot_state \ - /wild_visual_navigation_node/system_state:=/recorded_wvn/wild_visual_navigation_node/system_state \ - /wild_visual_navigation_visu_confidence/confidence_overlayed:=/recorded_wvn/wild_visual_navigation_visu_confidence/confidence_overlayed \ - /wild_visual_navigation_visu_traversability/traversability_overlayed:=/recorded_wvn/wild_visual_navigation_visu_traversability/traversability_overlayed" + args="$args /wild_visual_navigation_node/front/camera_info:=/recorded/wild_visual_navigation_node/front/camera_info \ + /wild_visual_navigation_node/front/confidence:=/recorded/wild_visual_navigation_node/front/confidence \ + /wild_visual_navigation_node/front/image_input:=/recorded/wild_visual_navigation_node/front/image_input \ + /wild_visual_navigation_node/front/traversability:=/recorded/wild_visual_navigation_node/front/traversability \ + /wild_visual_navigation_node/graph_footprints:=/recorded/wild_visual_navigation_node/graph_footprints \ + /wild_visual_navigation_node/instant_traversability:=/recorded/wild_visual_navigation_node/instant_traversability \ + /wild_visual_navigation_node/proprioceptive_graph:=/recorded/wild_visual_navigation_node/proprioceptive_graph \ + /wild_visual_navigation_node/robot_state:=/recorded/wild_visual_navigation_node/robot_state \ + /wild_visual_navigation_node/system_state:=/recorded/wild_visual_navigation_node/system_state \ + /wild_visual_navigation_visu_confidence/confidence_overlayed:=/recorded/wild_visual_navigation_visu_confidence/confidence_overlayed \ + /wild_visual_navigation_visu_traversability/traversability_overlayed:=/recorded/wild_visual_navigation_visu_traversability/traversability_overlayed" elif [ "$option" == "--flp" ]; then - args="$args /field_local_planner/action_server/status:=/recorded_flp/field_local_planner/action_server/status \ - /field_local_planner/current_base:=/recorded_flp/field_local_planner/current_base \ - /field_local_planner/current_goal:=/recorded_flp/field_local_planner/current_goal \ - /field_local_planner/parameter_descriptions:=/recorded_flp/field_local_planner/parameter_descriptions \ - /field_local_planner/parameter_updates:=/recorded_flp/field_local_planner/parameter_updates \ - /field_local_planner/path:=/recorded_flp/field_local_planner/path \ - /field_local_planner/real_carrot:=/recorded_flp/field_local_planner/real_carrot \ - /field_local_planner/rmp/control_points:=/recorded_flp/field_local_planner/rmp/control_points \ - /field_local_planner/rmp/parameter_descriptions:=/recorded_flp/field_local_planner/rmp/parameter_descriptions \ - /field_local_planner/rmp/parameter_updates:=/recorded_flp/field_local_planner/rmp/parameter_updates \ - /field_local_planner/status:=/recorded_flp/field_local_planner/status \ - /elevation_mapping/elevation_map_wifi:=/recorded_flp/elevation_mapping/elevation_map_wifi" + args="$args /field_local_planner/action_server/status:=/recorded/field_local_planner/action_server/status \ + /field_local_planner/current_base:=/recorded/field_local_planner/current_base \ + /field_local_planner/current_goal:=/recorded/field_local_planner/current_goal \ + /field_local_planner/parameter_descriptions:=/recorded/field_local_planner/parameter_descriptions \ + /field_local_planner/parameter_updates:=/recorded/field_local_planner/parameter_updates \ + /field_local_planner/path:=/recorded/field_local_planner/path \ + /field_local_planner/real_carrot:=/recorded/field_local_planner/real_carrot \ + /field_local_planner/rmp/control_points:=/recorded/field_local_planner/rmp/control_points \ + /field_local_planner/rmp/parameter_descriptions:=/recorded/field_local_planner/rmp/parameter_descriptions \ + /field_local_planner/rmp/parameter_updates:=/recorded/field_local_planner/rmp/parameter_updates \ + /field_local_planner/status:=/recorded/field_local_planner/status \ + /elevation_mapping/elevation_map_wifi:=/recorded/elevation_mapping/elevation_map_wifi" elif [ "$option" == "--tf" ]; then - args="$args /tf:=/recorded_flp/tf" - - # /tf_static:=/recorded_flp/tf_static" - + args="$args /tf:=/recorded/tf" + # /tf_static:=/recorded/tf_static" echo "rosrun anymal_rsl_launch replay.py c /media/Data/Datasets/2023_Oxford_Testing/2023_01_27_Oxford_Park/mission_data/2023-01-27-11-00-22/2023-01-27-11-00-22_anymal-coyote-lpc_mission.yaml" else args="$args $option" diff --git a/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py b/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py index 3d502213..ee941a46 100644 --- a/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py +++ b/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py @@ -1,8 +1,6 @@ from wild_visual_navigation import WVN_ROOT_DIR -from wild_visual_navigation.utils import load_yaml, load_env, create_experiment_folder from wild_visual_navigation.feature_extractor import FeatureExtractor -from wild_visual_navigation.cfg import ExperimentParams -from wild_visual_navigation.utils import override_params +from wild_visual_navigation.cfg import ExperimentParams, RosFeatureExtractorNodeParams from wild_visual_navigation.image_projector import ImageProjector from wild_visual_navigation_msgs.msg import ImageFeatures import wild_visual_navigation_ros.ros_converter as rc @@ -17,7 +15,7 @@ import os import torch import numpy as np -import dataclasses +from omegaconf import OmegaConf, read_write from torch_geometric.data import Data import torch.nn.functional as F from threading import Thread, Event @@ -31,38 +29,37 @@ class WvnFeatureExtractor: def __init__(self): # Read params self.read_params() - self.anomaly_detection = self.exp_cfg["model"]["name"] == "LinearRnvp" self.feature_extractor = FeatureExtractor( - self.device, - segmentation_type=self.segmentation_type, - feature_type=self.feature_type, - input_size=self.network_input_image_height, - slic_num_components=self.slic_num_components, - dino_dim=self.dino_dim, + self.ros_params.device, + segmentation_type=self.ros_params.segmentation_type, + feature_type=self.ros_params.feature_type, + input_size=self.ros_params.network_input_image_height, + slic_num_components=self.ros_params.slic_num_components, + dino_dim=self.ros_params.dino_dim, ) self.i = 0 - self.model = get_model(self.exp_cfg["model"]).to(self.device) + self.model = get_model(self.params.model).to(self.ros_params.device) self.model.eval() if not self.anomaly_detection: self.confidence_generator = ConfidenceGenerator( - method=self.exp_cfg["loss"]["method"], - std_factor=self.exp_cfg["loss"]["confidence_std_factor"], + method=self.params.loss.method, + std_factor=self.params.loss.confidence_std_factor, anomaly_detection=self.anomaly_detection, ) - self.scale_traversability = True + self.ros_params.scale_traversability = True else: self.traversability_loss = AnomalyLoss( - **self.exp_cfg["loss_anomaly"], - log_enabled=self.exp_cfg["general"]["log_confidence"], - log_folder=self.exp_cfg["general"]["model_path"], + **self.params.loss_anomaly, + log_enabled=self.params.general.log_confidence, + log_folder=self.params.general.model_path, ) - self.traversability_loss.to(self.device) - self.scale_traversability = False + self.traversability_loss.to(self.ros_params.device) + self.ros_params.scale_traversability = False - if self.verbose: + if self.ros_params.verbose: self.log_data = {} self.status_thread_stop_event = Event() self.status_thread = Thread(target=self.status_thread_loop, name="status") @@ -84,7 +81,7 @@ def shutdown_callback(self, *args, **kwargs): sys.exit(0) def status_thread_loop(self): - rate = rospy.Rate(self.status_thread_rate) + rate = rospy.Rate(self.ros_params.status_thread_rate) # Learning loop while self.run_status_thread: @@ -115,44 +112,24 @@ def status_thread_loop(self): # try: # rate.sleep() # except Exception as e: - # rate = rospy.Rate(self.status_thread_rate) + # rate = rospy.Rate(self.ros_params.status_thread_rate) # print("Ignored jump pack in time!") self.status_thread_stop_event.clear() def read_params(self): """Reads all the parameters from the parameter server""" - self.device = rospy.get_param("~device") - self.verbose = rospy.get_param("~verbose") - - # Topics - self.camera_topics = rospy.get_param("~camera_topics") - # Experiment file - self.network_input_image_height = rospy.get_param("~network_input_image_height") - self.network_input_image_width = rospy.get_param("~network_input_image_width") - - self.segmentation_type = rospy.get_param("~segmentation_type") - self.feature_type = rospy.get_param("~feature_type") - self.dino_patch_size = rospy.get_param("~dino_patch_size") - self.dino_dim = rospy.get_param("~dino_dim") - self.slic_num_components = rospy.get_param("~slic_num_components") - self.traversability_threshold = rospy.get_param("~traversability_threshold") - self.clip_to_binary = rospy.get_param("~clip_to_binary") - - self.confidence_std_factor = rospy.get_param("~confidence_std_factor") - self.scale_traversability = rospy.get_param("~scale_traversability") - self.scale_traversability_max_fpr = rospy.get_param("~scale_traversability_max_fpr") - self.status_thread_rate = rospy.get_param("~status_thread_rate") - self.prediction_per_pixel = rospy.get_param("~prediction_per_pixel") - # Initialize traversability estimator parameters - # Experiment file - exp_file = rospy.get_param("~exp") - self.params = ExperimentParams() - if exp_file != "nan": - exp_override = load_yaml(os.path.join(WVN_ROOT_DIR, "cfg/exp", exp_file)) - self.params = override_params(self.params, exp_override) - - self.exp_cfg = dataclasses.asdict(self.params) - self.exp_cfg["loss"]["confidence_std_factor"] = self.confidence_std_factor + self.params = OmegaConf.structured(ExperimentParams) + self.ros_params = OmegaConf.structured(RosFeatureExtractorNodeParams) + + # Override the empty dataclass with values from rosparm server + with read_write(self.ros_params): + for k in self.ros_params.keys(): + self.ros_params[k] = rospy.get_param(f"~{k}") + + with read_write(self.params): + self.params.loss.confidence_std_factor = self.ros_params.confidence_std_factor + + self.anomaly_detection = self.params.model.name == "LinearRnvp" def setup_ros(self, setup_fully=True): """Main function to setup ROS-related stuff: publishers, subscribers and services""" @@ -160,13 +137,13 @@ def setup_ros(self, setup_fully=True): self.camera_handler = {} - if self.verbose: + if self.ros_params.verbose: # DEBUG Logging self.log_data[f"time_last_model"] = -1 self.log_data[f"nr_model_updates"] = -1 - for cam in self.camera_topics: - if self.verbose: + for cam in self.ros_params.camera_topics: + if self.ros_params.verbose: # DEBUG Logging self.log_data[f"nr_images_{cam}"] = 0 self.log_data[f"time_last_image_{cam}"] = -1 @@ -174,40 +151,43 @@ def setup_ros(self, setup_fully=True): # Initialize camera handler for given cam self.camera_handler[cam] = {} # Store camera name - self.camera_topics[cam]["name"] = cam + self.ros_params.camera_topics[cam]["name"] = cam # Camera info - camera_info_msg = rospy.wait_for_message(self.camera_topics[cam]["info_topic"], CameraInfo, timeout=15) - self.camera_topics[cam]["camera_info"] = camera_info_msg + camera_info_msg = rospy.wait_for_message( + self.ros_params.camera_topics[cam]["info_topic"], CameraInfo, timeout=15 + ) + K, H, W = rc.ros_cam_info_to_tensors(camera_info_msg, device=self.ros_params.device) - K, H, W = rc.ros_cam_info_to_tensors(camera_info_msg, device=self.device) - self.camera_topics[cam]["K"] = K - self.camera_topics[cam]["H"] = H - self.camera_topics[cam]["W"] = W + self.camera_handler[cam]["camera_info"] = camera_info_msg + self.camera_handler[cam]["K"] = K + self.camera_handler[cam]["H"] = H + self.camera_handler[cam]["W"] = W image_projector = ImageProjector( - K=self.camera_topics[cam]["K"], - h=self.camera_topics[cam]["H"], - w=self.camera_topics[cam]["W"], - new_h=self.network_input_image_height, - new_w=self.network_input_image_width, + K=self.camera_handler[cam]["K"], + h=self.camera_handler[cam]["H"], + w=self.camera_handler[cam]["W"], + new_h=self.ros_params.network_input_image_height, + new_w=self.ros_params.network_input_image_width, ) - msg = self.camera_topics[cam]["camera_info"] - msg.width = self.network_input_image_width - msg.height = self.network_input_image_height + msg = self.camera_handler[cam]["camera_info"] + msg.width = self.ros_params.network_input_image_width + msg.height = self.ros_params.network_input_image_height msg.K = image_projector.scaled_camera_matrix[0, :3, :3].cpu().numpy().flatten().tolist() msg.P = image_projector.scaled_camera_matrix[0, :3, :4].cpu().numpy().flatten().tolist() - self.camera_topics[cam]["camera_info_msg_out"] = msg - self.camera_topics[cam]["image_projector"] = image_projector + with read_write(self.ros_params): + self.camera_handler[cam]["camera_info_msg_out"] = msg + self.camera_handler[cam]["image_projector"] = image_projector # Set subscribers - base_topic = self.camera_topics[cam]["image_topic"].replace("/compressed", "") - is_compressed = self.camera_topics[cam]["image_topic"] != base_topic + base_topic = self.ros_params.camera_topics[cam]["image_topic"].replace("/compressed", "") + is_compressed = self.ros_params.camera_topics[cam]["image_topic"] != base_topic if is_compressed: # TODO study the effect of the buffer size image_sub = rospy.Subscriber( - self.camera_topics[cam]["image_topic"], + self.ros_params.camera_topics[cam]["image_topic"], CompressedImage, self.image_callback, callback_args=cam, @@ -215,7 +195,11 @@ def setup_ros(self, setup_fully=True): ) else: image_sub = rospy.Subscriber( - self.camera_topics[cam]["image_topic"], Image, self.image_callback, callback_args=cam, queue_size=1 + self.ros_params.camera_topics[cam]["image_topic"], + Image, + self.image_callback, + callback_args=cam, + queue_size=1, ) self.camera_handler[cam]["image_sub"] = image_sub @@ -224,19 +208,19 @@ def setup_ros(self, setup_fully=True): info_pub = rospy.Publisher(f"/wild_visual_navigation_node/{cam}/camera_info", CameraInfo, queue_size=10) self.camera_handler[cam]["trav_pub"] = trav_pub self.camera_handler[cam]["info_pub"] = info_pub - if self.anomaly_detection and self.camera_topics[cam]["publish_confidence"]: + if self.anomaly_detection and self.ros_params.camera_topics[cam]["publish_confidence"]: print(colored("Warning force set public confidence to false", "red")) - self.camera_topics[cam]["publish_confidence"] = False + self.ros_params.camera_topics[cam]["publish_confidence"] = False - if self.camera_topics[cam]["publish_input_image"]: + if self.ros_params.camera_topics[cam]["publish_input_image"]: input_pub = rospy.Publisher(f"/wild_visual_navigation_node/{cam}/image_input", Image, queue_size=10) self.camera_handler[cam]["input_pub"] = input_pub - if self.camera_topics[cam]["publish_confidence"]: + if self.ros_params.camera_topics[cam]["publish_confidence"]: conf_pub = rospy.Publisher(f"/wild_visual_navigation_node/{cam}/confidence", Image, queue_size=10) self.camera_handler[cam]["conf_pub"] = conf_pub - if self.camera_topics[cam]["use_for_training"]: + if self.ros_params.camera_topics[cam]["use_for_training"]: imagefeat_pub = rospy.Publisher( f"/wild_visual_navigation_node/{cam}/feat", ImageFeatures, queue_size=10 ) @@ -251,7 +235,7 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo info_msg (sensor_msgs/CameraInfo): Camera info message associated to the image cam (str): Camera name """ - if self.verbose: + if self.ros_params.verbose: # DEBUG Logging self.log_data[f"nr_images_{cam}"] += 1 self.log_data[f"time_last_image_{cam}"] = rospy.get_time() @@ -259,8 +243,8 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo # Update model from file if possible self.load_model() # Convert image message to torch image - torch_image = rc.ros_image_to_torch(image_msg, device=self.device) - torch_image = self.camera_topics[cam]["image_projector"].resize_image(torch_image) + torch_image = rc.ros_image_to_torch(image_msg, device=self.ros_params.device) + torch_image = self.camera_handler[cam]["image_projector"].resize_image(torch_image) C, H, W = torch_image.shape _, feat, seg, center, dense_feat = self.feature_extractor.extract( @@ -270,7 +254,7 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo n_random_pixels=100, ) - if self.prediction_per_pixel: + if self.ros_params.prediction_per_pixel: # Evaluate traversability data = Data(x=dense_feat[0].permute(1, 2, 0).reshape(-1, dense_feat.shape[1])) else: @@ -286,15 +270,15 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo out_trav = prediction.reshape(H, W, -1)[:, :, 0] # Publish traversability - if self.scale_traversability: + if self.ros_params.scale_traversability: # Apply piecewise linear scaling 0->0; threshold->0.5; 1->1 traversability = out_trav.clone() - m = traversability < self.traversability_threshold + m = traversability < self.ros_params.traversability_threshold # Scale untraversable - traversability[m] *= 0.5 / self.traversability_threshold + traversability[m] *= 0.5 / self.ros_params.traversability_threshold # Scale traversable - traversability[~m] -= self.traversability_threshold - traversability[~m] *= 0.5 / (1 - self.traversability_threshold) + traversability[~m] -= self.ros_params.traversability_threshold + traversability[~m] *= 0.5 / (1 - self.ros_params.traversability_threshold) traversability[~m] += 0.5 traversability = traversability.clip(0, 1) # TODO Check if this was a bug @@ -305,8 +289,8 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo out_trav = trav.reshape(H, W, -1)[:, :, 0] # Clip to binary output - if self.clip_to_binary: - out_trav = torch.where(out_trav.squeeze() <= self.traversability_threshold, 0.0, 1.0) + if self.ros_params.clip_to_binary: + out_trav = torch.where(out_trav.squeeze() <= self.ros_params.traversability_threshold, 0.0, 1.0) msg = rc.numpy_to_ros_image(out_trav.cpu().numpy(), "passthrough") msg.header = image_msg.header @@ -314,12 +298,12 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo msg.height = out_trav.shape[1] self.camera_handler[cam]["trav_pub"].publish(msg) - msg = self.camera_topics[cam]["camera_info_msg_out"] + msg = self.camera_handler[cam]["camera_info_msg_out"] msg.header = image_msg.header self.camera_handler[cam]["info_pub"].publish(msg) # Publish image - if self.camera_topics[cam]["publish_input_image"]: + if self.ros_params.camera_topics[cam]["publish_input_image"]: msg = rc.numpy_to_ros_image((torch_image.permute(1, 2, 0) * 255).cpu().numpy().astype(np.uint8), "rgb8") msg.header = image_msg.header msg.width = torch_image.shape[1] @@ -327,7 +311,7 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo self.camera_handler[cam]["input_pub"].publish(msg) # Publish confidence - if self.camera_topics[cam]["publish_confidence"]: + if self.ros_params.camera_topics[cam]["publish_confidence"]: loss_reco = F.mse_loss(prediction[:, 1:], data.x, reduction="none").mean(dim=1) confidence = self.confidence_generator.inference_without_update(x=loss_reco) out_confidence = confidence.reshape(H, W) @@ -338,7 +322,7 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo self.camera_handler[cam]["conf_pub"].publish(msg) # Publish features and feature_segments - if self.camera_topics[cam]["use_for_training"]: + if self.ros_params.camera_topics[cam]["use_for_training"]: msg = ImageFeatures() msg.header = image_msg.header msg.feature_segments = rc.numpy_to_ros_image(seg.cpu().numpy().astype(np.int32), "passthrough") @@ -368,7 +352,7 @@ def load_model(self): k = list(self.model.state_dict().keys())[-1] if (self.model.state_dict()[k] != res[k]).any(): - if self.verbose: + if self.ros_params.verbose: self.log_data[f"time_last_model"] = rospy.get_time() self.log_data[f"nr_model_updates"] += 1 @@ -376,7 +360,7 @@ def load_model(self): try: if res["traversability_threshold"] is not None: - self.traversability_threshold = res["traversability_threshold"] + self.ros_params.traversability_threshold = res["traversability_threshold"] if res["confidence_generator"] is not None: self.confidence_generator_state = res["confidence_generator"] @@ -388,7 +372,7 @@ def load_model(self): pass except Exception as e: - if self.verbose: + if self.ros_params.verbose: print(f"Model Loading Failed: {e}") @@ -405,9 +389,6 @@ def load_model(self): os.system( f"rosparam load {wvn_path}/config/wild_visual_navigation/inputs/wide_angle_front_compressed.yaml wvn_feature_extractor_node" ) - print( - f"rosparam load {wvn_path}/config/wild_visual_navigation/inputs/wide_angle_front_compressed.yaml wvn_feature_extractor_node" - ) wvn = WvnFeatureExtractor() rospy.spin()