Skip to content

Commit

Permalink
confidence generation certainly broken
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasFrey96 committed Feb 17, 2024
1 parent a7e5ea6 commit af50296
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .deprecated/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .graph_trav_dataset import get_ablation_module
from .graph_trav_dataset import get_ablation_module
14 changes: 14 additions & 0 deletions wild_visual_navigation/utils/confidence_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ def update_running_mean(self, x: torch.tensor, x_positive: torch.tensor):
confidence = torch.exp(-(((x - self.mean) / (self.std * self.std_factor)) ** 2) * 0.5)
confidence[x < self.mean] = 1.0

# shifted_mean = self.mean + self.std*self.std_factor
# interval_min = shifted_mean - 2 * self.std
# interval_max = shifted_mean + 2 * self.std
# x = torch.clip( x , interval_min, interval_max)
# confidence = 1 - ((x - interval_min) / (interval_max - interval_min))

return confidence.type(torch.float32)

def update_moving_average(self, x: torch.tensor, x_positive: torch.tensor):
Expand Down Expand Up @@ -184,8 +190,16 @@ def inference_without_update(self, x: torch.tensor):
if self.anomaly_detection:
x = torch.clip(x, self.mean - 2 * self.std, self.mean + 2 * self.std)
confidence = (x - torch.min(x)) / (torch.max(x) - torch.min(x))

else:
# shifted_mean = self.mean + self.std*self.std_factor
# interval_min = shifted_mean - 2 * self.std
# interval_max = shifted_mean + 2 * self.std
# x = torch.clip( x , interval_min, interval_max)
# confidence = 1 - ((x - interval_min) / (interval_max - interval_min))

confidence = torch.exp(-(((x - self.mean) / (self.std * self.std_factor)) ** 2) * 0.5)
confidence[x < self.mean] = 1.0

return confidence.type(torch.float32)

Expand Down
10 changes: 7 additions & 3 deletions wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo

except Exception as e:
traceback.print_exc()
rospy.logerr(f"[{self._node_name}] error image callback", e)
rospy.logerr(f"[self._node_name] error image callback", e)
self.system_events["image_callback_state"] = {
"time": rospy.get_time(),
"value": f"failed to execute {e}",
Expand Down Expand Up @@ -435,14 +435,18 @@ def load_model(self, stamp):
self._confidence_generator.std = cg["std"]
else:
if self._ros_params.verbose:
rospy.logerr(f"[{self._node_name}] Model Loading Failed: {e}")
rospy.logerr(f"[{self._node_name}] Model Loading Failed")


if __name__ == "__main__":
node_name = "wvn_feature_extractor_node"
rospy.init_node(node_name)

reload_rosparams(enabled=rospy.get_param("~reload_default_params", True), node_name=node_name, camera_cfg="hdr")
reload_rosparams(
enabled=rospy.get_param("~reload_default_params", True),
node_name=node_name,
camera_cfg="wide_angle_dual_resize",
)

wvn = WvnFeatureExtractor(node_name)
rospy.spin()
6 changes: 5 additions & 1 deletion wild_visual_navigation_ros/scripts/wvn_learning_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,10 @@ def query_tf(self, parent_frame: str, child_frame: str, stamp: Optional[rospy.Ti
node_name = "wvn_learning_node"
rospy.init_node(node_name)

reload_rosparams(enabled=rospy.get_param("~reload_default_params", True), node_name=node_name, camera_cfg="hdr")
reload_rosparams(
enabled=rospy.get_param("~reload_default_params", True),
node_name=node_name,
camera_cfg="wide_angle_dual_resize",
)
wvn = WvnLearning(node_name)
rospy.spin()

0 comments on commit af50296

Please sign in to comment.