Skip to content

Commit

Permalink
removed trav and cleaned up config
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasFrey96 committed Feb 17, 2024
1 parent 96dadfa commit 6c73d70
Show file tree
Hide file tree
Showing 10 changed files with 95 additions and 322 deletions.
16 changes: 2 additions & 14 deletions wild_visual_navigation/cfg/ros_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,10 @@ class RosLearningNodeParams:
traversability_radius: float # meters
image_graph_dist_thr: float # meters
supervision_graph_dist_thr: float # meters
network_input_image_height: int # 448
network_input_image_width: int # 448
segmentation_type: str
feature_type: str
dino_patch_size: int # 8 or 16; 8 is finer
dino_backbone: str # vit_small, vit_base
slic_num_components: int
confidence_std_factor: float
scale_traversability: bool # This parameter needs to be false when using the anomaly detection model
scale_traversability_max_fpr: float
min_samples_for_training: int

traversability_threshold: float
network_input_image_height: int
network_input_image_width: int
vis_node_index: int

# Supervision Generator
Expand Down Expand Up @@ -82,12 +73,9 @@ class RosFeatureExtractorNodeParams:

# ConfidenceGenerator
confidence_std_factor: float
scale_traversability: bool # This parameter needs to be false when using the anomaly detection model

# Output setting
prediction_per_pixel: bool
traversability_threshold: float
clip_to_binary: bool

# Runtime options
mode: Any # check out comments in the class WVNMode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
class SupervisionGenerator:
def __init__(
self,
device: str = "cuda",
kf_process_cov: float = 0.01,
kf_meas_cov: float = 10,
kf_outlier_rejection: str = "none",
kf_outlier_rejection_delta: float = 1.0,
sigmoid_slope: float = 15,
sigmoid_cutoff: float = 0.2,
untraversable_thr: float = 0.1,
time_horizon: float = 1,
graph_max_length: float = 1,
device: str,
kf_process_cov,
kf_meas_cov,
kf_outlier_rejection,
kf_outlier_rejection_delta,
sigmoid_slope,
sigmoid_cutoff,
untraversable_thr,
time_horizon,
graph_max_length,
):
"""Generates traversability signals/labels from different sources
Expand Down Expand Up @@ -201,6 +201,8 @@ def run_supervision_generator():
sigmoid_slope=30,
sigmoid_cutoff=0.2,
untraversable_thr=0.05,
time_horizon=0.05,
graph_max_length=1,
)

# Saved data list
Expand Down
9 changes: 0 additions & 9 deletions wild_visual_navigation/traversability_estimator/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def __init__(

# Uninitialized members
self._features = None
self._feature_type = None
self._feature_edges = None
self._feature_segments = None
self._feature_positions = None
Expand Down Expand Up @@ -258,10 +257,6 @@ def confidence(self):
def features(self):
return self._features

@property
def feature_type(self):
return self._feature_type

@property
def feature_edges(self):
return self._feature_edges
Expand Down Expand Up @@ -318,10 +313,6 @@ def confidence(self, confidence):
def features(self, features):
self._features = features

@feature_type.setter
def feature_type(self, feature_type):
self._feature_type = feature_type

@feature_edges.setter
def feature_edges(self, feature_edges):
self._feature_edges = feature_edges
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import pickle
import torch
import torchvision.transforms as transforms
from torchmetrics import ROC

to_tensor = transforms.ToTensor()

Expand All @@ -30,38 +29,24 @@ class TraversabilityEstimator:
def __init__(
self,
params: ExperimentParams,
scale_traversability: bool,
device: str = "cuda",
max_distance: float = 3,
image_size: int = 448,
image_distance_thr: float = None,
supervision_distance_thr: float = None,
segmentation_type: str = "slic",
feature_type: str = "dino",
min_samples_for_training: int = 10,
vis_node_index: int = 10,
mode: bool = False,
extraction_store_folder=None,
anomaly_detection: bool = False,
vis_training_samples: bool = False,
use_feature_extractor: bool = False,
**kwargs,
device: str,
max_distance: float,
image_distance_thr: float,
supervision_distance_thr: float,
min_samples_for_training: int,
vis_node_index: int,
mode: bool,
extraction_store_folder,
anomaly_detection: bool,
):
self._device = device
self._mode = mode
self._extraction_store_folder = extraction_store_folder
self._min_samples_for_training = min_samples_for_training
self._vis_node_index = vis_node_index
self._scale_traversability = scale_traversability
self._params = params
self._scale_traversability_threshold = 0
self._anomaly_detection = anomaly_detection

if self._scale_traversability:
# Use 500 bins for constant memory usuage
self._auxiliary_training_roc = ROC(task="binary", thresholds=5000)
self._auxiliary_training_roc.to(self._device)

# Local graphs
self._supervision_graph = DistanceWindowGraph(max_distance=max_distance, edge_distance=supervision_distance_thr)

Expand All @@ -74,20 +59,6 @@ def __init__(
# Visualization node
self._vis_mission_node = None

# Feature extractor
self._segmentation_type = segmentation_type
self._feature_type = feature_type

self._use_feature_extractor = use_feature_extractor
if use_feature_extractor:
self._feature_extractor = FeatureExtractor(
self._device,
segmentation_type=self._segmentation_type,
feature_type=self._feature_type,
input_size=image_size,
**kwargs,
)

# Mutex
self._learning_lock = Lock()

Expand Down Expand Up @@ -143,48 +114,6 @@ def __setstate__(self, state: dict):

def reset(self):
print("[WARNING] Resetting the traversability estimator is not fully tested")
# with self._learning_lock:
# self._pause_training = True
# self._pause_mission_graph = True
# self._pause_supervision_graph = True
# time.sleep(2.0)

# self._supervision_graph.clear()
# self._mission_graph.clear()

# # Reset all the learning stuff
# self._step = 0
# self._loss = torch.tensor([torch.inf])

# # Re-create model
# self._params = dataclasses.asdict(self._params)
# self._model = get_model(self._params["model"]).to(self._device)
# self._model.train()

# # Re-create optimizer
# self._optimizer = torch.optim.Adam(self._model.parameters(), lr=self._params["optimizer"]["lr"])

# # Re-create traversability loss
# self._traversability_loss = TraversabilityLoss(
# **self._params["loss"],
# model=self._model,
# log_enabled=self._params["general"]["log_confidence"],
# log_folder=self._params["general"]["model_path"],
# )
# self._traversability_loss.to(self._device)

# # Resume training
# self._pause_training = False
# self._pause_mission_graph = False
# self._pause_supervision_graph = False

@property
def scale_traversability_threshold(self):
return self._scale_traversability_threshold

@scale_traversability_threshold.setter
def scale_traversability_threshold(self, scale_traversability_threshold):
self._scale_traversability_threshold = scale_traversability_threshold

@property
def loss(self):
Expand Down Expand Up @@ -215,42 +144,6 @@ def change_device(self, device: str):

if self._use_feature_extractor:
self._feature_extractor.change_device(device)
if self._scale_traversability:
# Use 500 bins for constant memory usuage
self._auxiliary_training_roc.to(device)

@accumulate_time
def update_features(self, node: MissionNode):
"""Extracts visual features from a node that stores an image
Args:
node (MissionNode): new node in the mission graph
"""
if not self._use_feature_extractor:
raise ValueError(
"Udate features can be not called given that when creating the TraversabilityEstimator the FeatureExtractor was not used: use_feature_extractor = False"
)

if self._mode != WVNMode.EXTRACT_LABELS:
# Extract features
# Check do we need to add here the .clone() in
edges, feat, seg, center = self._feature_extractor.extract(img=node.image[None], return_centers=True)

# Set features in node
node.feature_type = self._feature_extractor.feature_type
node.features = feat
node.feature_edges = edges
node.feature_segments = seg
node.feature_positions = center

@accumulate_time
def update_prediction(self, node: MissionNode):
data = Data(x=node.features, edge_index=node.feature_edges)
with torch.inference_mode():
with self._learning_lock:
node.prediction = self._model(data)
# TODO Check where node confidence is actually used
self._traversability_loss.update_node_confidence(node)

@accumulate_time
def update_visualization_node(self):
Expand All @@ -265,7 +158,7 @@ def update_visualization_node(self):
self._vis_mission_node = self._mission_graph.get_nodes()[-self._vis_node_index]

@accumulate_time
def add_mission_node(self, node: MissionNode, verbose: bool = False, update_features: bool = True):
def add_mission_node(self, node: MissionNode, verbose: bool = False):
"""Adds a node to the mission graph to images and training info
Args:
Expand All @@ -275,10 +168,6 @@ def add_mission_node(self, node: MissionNode, verbose: bool = False, update_feat
if self._pause_mission_graph:
return False

if update_features:
# Compute image features
self.update_features(node)

# Add image node
success = self._mission_graph.add_node(node)

Expand All @@ -297,17 +186,9 @@ def add_mission_node(self, node: MissionNode, verbose: bool = False, update_feat
node.supervision_mask = supervision_mask
node.update_supervision_signal()

if self._mode == WVNMode.EXTRACT_LABELS:
p = os.path.join(
self._extraction_store_folder,
"image",
str(node.timestamp).replace(".", "_") + ".pt",
)
torch.save(node.image, p)

return True
else:
return False

return False

@accumulate_time
@torch.no_grad()
Expand Down Expand Up @@ -357,7 +238,6 @@ def add_supervision_node(self, pnode: SupervisionNode):
if last_mission_node.timestamp - values["timestamp"] > 30:
node.clear_debug_data()
self._debug_info_node_count += 1
#length = len(list(self._mission_graph._graph.nodes._nodes.items()))
else:
break

Expand All @@ -367,7 +247,6 @@ def add_supervision_node(self, pnode: SupervisionNode):
)

if len(mission_nodes) < 1:

return False

# Set color
Expand All @@ -389,9 +268,7 @@ def add_supervision_node(self, pnode: SupervisionNode):
K[i] = mnode.image_projector.camera.intrinsics
pose_camera_in_world[i] = mnode.pose_cam_in_world

if (not hasattr(mnode, "supervision_mask")) or (mnode.supervision_mask is None):
continue
else:
if not ((not hasattr(mnode, "supervision_mask")) or (mnode.supervision_mask is None)):
supervision_masks[i] = mnode.supervision_mask

im = ImageProjector(K, H, W)
Expand All @@ -415,11 +292,6 @@ def add_supervision_node(self, pnode: SupervisionNode):
store = torch.nan_to_num(mnode.supervision_mask.nanmean(axis=0)) != 0
torch.save(store, p)

# if self._anomaly_detection:
# # Visualize supervision mask
# store = torch.nan_to_num(mnode.supervision_mask.nanmean(axis=0)) != 0
# self._last_image_mask_pub.publish(self._bridge.cv2_to_imgmsg(store.cpu().numpy().astype(np.uint8) * 255, "mono8"))

return True

def get_mission_nodes(self):
Expand All @@ -436,10 +308,6 @@ def get_last_valid_mission_node(self):
return last_valid_node

def get_mission_node_for_visualization(self):
# print(f"get_mission_node_for_visualization: {self._vis_mission_node}")
# if self._vis_mission_node is not None:
# print(f" has image {hasattr(self._vis_mission_node, 'image')}")
# print(f" has supervision_mask {hasattr(self._vis_mission_node, 'supervision_mask')}")
return self._vis_mission_node

def save(self, mission_path: str, filename: str):
Expand Down Expand Up @@ -598,18 +466,6 @@ def train(self):
graph, res, step=self._step, log_step=log_step
)

# Keep track of ROC during training for rescaling the loss when publishing
if self._scale_traversability:
# This mask should contain all the segments corrosponding to trees.
mask_anomaly = loss_aux["confidence"] < 0.5
mask_supervision = graph.y == 1
# Remove the segments that are for sure not an anomalies given that we have walked on them.
mask_anomaly[mask_supervision] = False
# Elements are valid if they are either an anomaly or we have walked on them to fit the ROC
mask_valid = mask_anomaly | mask_supervision
self._auxiliary_training_roc(res[mask_valid, 0], graph.y[mask_valid].type(torch.long))
return_dict["scale_traversability_threshold"] = self._scale_traversability_threshold

# Backprop
self._optimizer.zero_grad()
self._loss.backward()
Expand Down
2 changes: 0 additions & 2 deletions wild_visual_navigation_msgs/msg/SystemState.msg
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,3 @@ float32 loss_reco
uint32 step
# Pause learning flag
bool pause_learning
# Traversability thereshold scaling
float32 scale_traversability_threshold
2 changes: 0 additions & 2 deletions wild_visual_navigation_ros/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,4 @@ catkin_install_python(PROGRAMS scripts/wvn_feature_extractor_node.py
scripts/overlay_images.py
scripts/smart_carrot.py
scripts/rosbag_play.sh
scripts/open_source_rosbag_converter/convert_to_public_format.py
scripts/open_source_rosbag_converter/encoding_fixer.py
DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION})
Loading

0 comments on commit 6c73d70

Please sign in to comment.