Skip to content

Commit

Permalink
Fix scheduler, fix naming of private variables
Browse files Browse the repository at this point in the history
  • Loading branch information
mmattamala committed Feb 3, 2024
1 parent 6e0fb8b commit 9e3c01e
Show file tree
Hide file tree
Showing 8 changed files with 233 additions and 213 deletions.
1 change: 1 addition & 0 deletions tests/test_feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def test_feature_extractor():
# Store results to test directory
img = get_img_from_fig(fig)
img.save(join(outpath, f"forest_clean_graph_{seg_type}_{feat_type}.png"))
plt.close()


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion wild_visual_navigation/cfg/experiment_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class ModelParams:

@dataclass
class SimpleMlpCfgParams:
input_size: int = 384 # 90 for stego, 384 for dino
input_size: int = 90 # 90 for stego, 384 for dino
hidden_sizes: List[int] = field(default_factory=lambda: [256, 32, 1])
reconstruction: bool = True

Expand Down
2 changes: 2 additions & 0 deletions wild_visual_navigation/cfg/ros_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class RosLearningNodeParams:
supervision_callback_rate: float # hertz
learning_thread_rate: float # hertz
logging_thread_rate: float # hertz
load_save_checkpoint_rate: float # hert

# Runtime options
device: str
Expand Down Expand Up @@ -97,3 +98,4 @@ class RosFeatureExtractorNodeParams:

# Threads
image_callback_rate: float # hertz
load_save_checkpoint_rate: float # hertz
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def adjacency_list(self, seg: torch.tensor):
Returns:
adjacency_list (torch.Tensor, dtype=torch.long, shape=(N, 2): Adjacency list of undirected graph
"""
assert seg.shape[0] == 1 and len(seg.shape) == 4
assert seg.shape[0] == 1 and len(seg.shape) == 4, f"{seg.shape}"

res = self.f1(seg.type(torch.float32))
boundary_mask = (res != 0)[0, :, 2:-2, 2:-2]
Expand Down
4 changes: 2 additions & 2 deletions wild_visual_navigation/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def load_test_image():
def get_dino_transform():
transform = T.Compose(
[
T.Resize(448, T.InterpolationMode.NEAREST),
T.CenterCrop(448),
T.Resize(224, T.InterpolationMode.NEAREST),
T.CenterCrop(224),
]
)
return transform
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ image_graph_dist_thr: 0.2 # meters
supervision_graph_dist_thr: 0.1 # meters
network_input_image_height: 224 # 448
network_input_image_width: 224 # 448
segmentation_type: "random" # Options: slic, grid, random, stego
feature_type: "dino" # Options: dino, dinov2, stego
segmentation_type: "stego" # Options: slic, grid, random, stego
feature_type: "stego" # Options: dino, dinov2, stego
dino_patch_size: 8 # 8 or 16; 8 is finer
dino_backbone: vit_small
slic_num_components: 100
Expand All @@ -46,6 +46,7 @@ supervision_callback_rate: 10 # hertz
learning_thread_rate: 10 # hertz
logging_thread_rate: 2 # hertz
status_thread_rate: 0.5 # hertz
load_save_checkpoint_rate: 0.2 # hertz, 1/0.2 = 5 sec equivalent

# Runtime options
device: "cuda"
Expand Down
15 changes: 12 additions & 3 deletions wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, node_name):

# Timers to control the rate of the subscriber
self._last_image_ts = rospy.get_time()
self._last_checkpoint_ts = rospy.get_time()

# Load model
self._model = get_model(self._params.model).to(self._ros_params.device)
Expand Down Expand Up @@ -280,7 +281,6 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo
else:
if self._ros_params.verbose:
rospy.loginfo(f"[{self._node_name}] Image callback: {cam} -> Process")
self._scheduler.step()

self._last_image_ts = ts

Expand All @@ -292,7 +292,7 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo
self._log_data[f"time_last_image_{cam}"] = rospy.get_time()

# Update model from file if possible
self.load_model()
self.load_model(image_msg.header.stamp)

# Convert image message to torch image
torch_image = rc.ros_image_to_torch(image_msg, device=self._ros_params.device)
Expand Down Expand Up @@ -414,12 +414,21 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo
}
raise Exception("Error in image callback")

def load_model(self):
# Step scheduler
self._scheduler.step()

def load_model(self, stamp):
"""Method to load the new model weights to perform inference on the incoming images
Args:
None
"""
ts = stamp.to_sec()
if abs(ts - self._last_checkpoint_ts) < 1.0 / self._ros_params.load_save_checkpoint_rate:
return

self._last_checkpoint_ts = ts

try:
# self._load_model_counter += 1
# if self._load_model_counter % 10 == 0:
Expand Down
Loading

0 comments on commit 9e3c01e

Please sign in to comment.