From af50296f8f005744748123fc6156335c5035ba13 Mon Sep 17 00:00:00 2001 From: Jonas Frey Date: Sat, 17 Feb 2024 22:16:15 +0100 Subject: [PATCH] confidence generation certainly broken --- .deprecated/dataset/__init__.py | 2 +- .../utils/confidence_generator.py | 14 ++++++++++++++ .../scripts/wvn_feature_extractor_node.py | 10 +++++++--- .../scripts/wvn_learning_node.py | 6 +++++- 4 files changed, 27 insertions(+), 5 deletions(-) diff --git a/.deprecated/dataset/__init__.py b/.deprecated/dataset/__init__.py index cf877c91..7e048bc8 100644 --- a/.deprecated/dataset/__init__.py +++ b/.deprecated/dataset/__init__.py @@ -1 +1 @@ -from .graph_trav_dataset import get_ablation_module \ No newline at end of file +from .graph_trav_dataset import get_ablation_module diff --git a/wild_visual_navigation/utils/confidence_generator.py b/wild_visual_navigation/utils/confidence_generator.py index 2a013e7b..e0803366 100644 --- a/wild_visual_navigation/utils/confidence_generator.py +++ b/wild_visual_navigation/utils/confidence_generator.py @@ -110,6 +110,12 @@ def update_running_mean(self, x: torch.tensor, x_positive: torch.tensor): confidence = torch.exp(-(((x - self.mean) / (self.std * self.std_factor)) ** 2) * 0.5) confidence[x < self.mean] = 1.0 + # shifted_mean = self.mean + self.std*self.std_factor + # interval_min = shifted_mean - 2 * self.std + # interval_max = shifted_mean + 2 * self.std + # x = torch.clip( x , interval_min, interval_max) + # confidence = 1 - ((x - interval_min) / (interval_max - interval_min)) + return confidence.type(torch.float32) def update_moving_average(self, x: torch.tensor, x_positive: torch.tensor): @@ -184,8 +190,16 @@ def inference_without_update(self, x: torch.tensor): if self.anomaly_detection: x = torch.clip(x, self.mean - 2 * self.std, self.mean + 2 * self.std) confidence = (x - torch.min(x)) / (torch.max(x) - torch.min(x)) + else: + # shifted_mean = self.mean + self.std*self.std_factor + # interval_min = shifted_mean - 2 * self.std + # interval_max = shifted_mean + 2 * self.std + # x = torch.clip( x , interval_min, interval_max) + # confidence = 1 - ((x - interval_min) / (interval_max - interval_min)) + confidence = torch.exp(-(((x - self.mean) / (self.std * self.std_factor)) ** 2) * 0.5) + confidence[x < self.mean] = 1.0 return confidence.type(torch.float32) diff --git a/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py b/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py index 049e8b57..f72f4fd5 100644 --- a/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py +++ b/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py @@ -390,7 +390,7 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo except Exception as e: traceback.print_exc() - rospy.logerr(f"[{self._node_name}] error image callback", e) + rospy.logerr(f"[self._node_name] error image callback", e) self.system_events["image_callback_state"] = { "time": rospy.get_time(), "value": f"failed to execute {e}", @@ -435,14 +435,18 @@ def load_model(self, stamp): self._confidence_generator.std = cg["std"] else: if self._ros_params.verbose: - rospy.logerr(f"[{self._node_name}] Model Loading Failed: {e}") + rospy.logerr(f"[{self._node_name}] Model Loading Failed") if __name__ == "__main__": node_name = "wvn_feature_extractor_node" rospy.init_node(node_name) - reload_rosparams(enabled=rospy.get_param("~reload_default_params", True), node_name=node_name, camera_cfg="hdr") + reload_rosparams( + enabled=rospy.get_param("~reload_default_params", True), + node_name=node_name, + camera_cfg="wide_angle_dual_resize", + ) wvn = WvnFeatureExtractor(node_name) rospy.spin() diff --git a/wild_visual_navigation_ros/scripts/wvn_learning_node.py b/wild_visual_navigation_ros/scripts/wvn_learning_node.py index f528cad2..0cd24cdf 100644 --- a/wild_visual_navigation_ros/scripts/wvn_learning_node.py +++ b/wild_visual_navigation_ros/scripts/wvn_learning_node.py @@ -943,6 +943,10 @@ def query_tf(self, parent_frame: str, child_frame: str, stamp: Optional[rospy.Ti node_name = "wvn_learning_node" rospy.init_node(node_name) - reload_rosparams(enabled=rospy.get_param("~reload_default_params", True), node_name=node_name, camera_cfg="hdr") + reload_rosparams( + enabled=rospy.get_param("~reload_default_params", True), + node_name=node_name, + camera_cfg="wide_angle_dual_resize", + ) wvn = WvnLearning(node_name) rospy.spin()