Skip to content

Commit

Permalink
running on the robot
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasFrey96 committed Feb 8, 2024
1 parent 8b7c6eb commit f06e079
Show file tree
Hide file tree
Showing 33 changed files with 309 additions and 144 deletions.
4 changes: 2 additions & 2 deletions scripts/dataset_generation/create_gnn_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import argparse
import cv2


from torch_geometric.data import Data
# TODO
# from torch_geometric.data import Data

if __name__ == "__main__":
"""Converts a folder with images to a torch_geometric dataformat.
Expand Down
4 changes: 3 additions & 1 deletion scripts/dataset_generation/extract_features_for_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import os
from pathlib import Path
import torch
from torch_geometric.data import Data

# TODO
# from torch_geometric.data import Data
from wild_visual_navigation.utils import KLTTrackerOpenCV
import numpy as np
from tqdm import tqdm
Expand Down
1 change: 1 addition & 0 deletions wild_visual_navigation/cfg/global_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ def get_global_env_params(name):
configs = {
"default": GlobalEnvironmentParams(perugia_root="TBD", results="results"),
"ge76": GlobalEnvironmentParams(perugia_root="TBD", results="results"),
"jetson": GlobalEnvironmentParams(perugia_root="TBD", results="results"),
}
return configs[name]
7 changes: 4 additions & 3 deletions wild_visual_navigation/dataset/graph_trav_dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from torch_geometric.data import InMemoryDataset, DataLoader
from torch_geometric.data import LightningDataset
# TODO
# from torch_geometric.data import InMemoryDataset, DataLoader
# from torch_geometric.data import LightningDataset
# from torch_geometric.data import Dataset

from wild_visual_navigation import WVN_ROOT_DIR
import os
import torch
from torch_geometric.data import Dataset
from torchvision import transforms as T
from typing import Optional, Callable
import random
Expand Down
3 changes: 2 additions & 1 deletion wild_visual_navigation/lightning/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import torch.nn.functional as F
from wild_visual_navigation.visu import LearningVisualizer
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
from torch_geometric.data import Data

# from torch_geometric.data import Data
from torchmetrics import ROC

from wild_visual_navigation.utils import TraversabilityLoss, MetricLogger
Expand Down
2 changes: 1 addition & 1 deletion wild_visual_navigation/model/linear_rnvp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import torch
from torch import nn, distributions
from torch_geometric.data import Data
from wild_visual_navigation.utils import Data


class LinearBatchNorm(nn.Module):
Expand Down
5 changes: 3 additions & 2 deletions wild_visual_navigation/model/simple_gcn.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data

# from torch_geometric.nn import GCNConv
from wild_visual_navigation.utils import Data


class SimpleGCN(torch.nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion wild_visual_navigation/model/simple_mlp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from torch_geometric.data import Data
from wild_visual_navigation.utils import Data


class SimpleMLP(torch.nn.Module):
Expand Down
7 changes: 5 additions & 2 deletions wild_visual_navigation/traversability_estimator/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
make_dense_plane,
)
from liegroups.torch import SE3, SO3
from torch_geometric.data import Data
from wild_visual_navigation.utils import Data

import os
import torch
from typing import Optional
Expand Down Expand Up @@ -166,10 +167,12 @@ def change_device(self, device):
"""
super().change_device(device)
self._image_projector.change_device(device)
self._image = self._image.to(device)

self._pose_cam_in_base = self._pose_cam_in_base.to(device)
self._pose_cam_in_world = self._pose_cam_in_world.to(device)

if self._image is not None:
self._image = self._image.to(device)
if self._features is not None:
self._features = self._features.to(device)
if self._feature_edges is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from wild_visual_navigation.visu import LearningVisualizer

from pytorch_lightning import seed_everything
from torch_geometric.data import Data, Batch
from wild_visual_navigation.utils import Data, Batch
from threading import Lock
import os
import pickle
Expand Down Expand Up @@ -124,6 +124,7 @@ def __init__(
self._optimizer = torch.optim.Adam(self._model.parameters(), lr=self._params["optimizer"]["lr"])
self._loss = torch.tensor([torch.inf])
self._step = 0
self._debug_info_node_count = 0

torch.set_grad_enabled(True)

Expand Down Expand Up @@ -288,9 +289,9 @@ def add_mission_node(self, node: MissionNode, verbose: bool = False, update_feat
s += " " * (48 - len(s)) + f"total nodes [{total_nodes}]"
if verbose:
print(s)

h, w = node._feature_segments.shape[0], node._feature_segments.shape[1]
# Project past footprints on current image
supervision_mask = torch.ones(node.image.shape).to(self._device) * torch.nan
supervision_mask = torch.ones((3, h, w)).to(self._device) * torch.nan

# Finally overwrite the current mask
node.supervision_mask = supervision_mask
Expand Down Expand Up @@ -334,6 +335,7 @@ def add_supervision_node(self, pnode: SupervisionNode):
return False

else:

# If the previous node doesn't exist or it's invalid, we do nothing
if last_pnode is None or not last_pnode.is_valid():
return False
Expand All @@ -348,6 +350,18 @@ def add_supervision_node(self, pnode: SupervisionNode):
if (not hasattr(last_mission_node, "supervision_mask")) or (last_mission_node.supervision_mask is None):
return False

for j, ele in enumerate(
list(self._mission_graph._graph.nodes._nodes.items())[self._debug_info_node_count :]
):
node, values = ele
if last_mission_node.timestamp - values["timestamp"] > 20:
node.clear_debug_data()
self._debug_info_node_count += 1
length = len(list(self._mission_graph._graph.nodes._nodes.items()))
print(
f"cleaned node {self._debug_info_node_count} nodes {self._debug_info_node_count}, length {length}"
)

# Get all mission nodes within a range
mission_nodes = self._mission_graph.get_nodes_within_radius_range(
last_mission_node, 0, self._supervision_graph.max_distance
Expand Down
1 change: 1 addition & 0 deletions wild_visual_navigation/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .data import Data, Batch
from .flatten_dict import *
from .get_logger import get_logger, get_neptune_run
from .loading import load_yaml, file_path, save_omega_cfg
Expand Down
68 changes: 68 additions & 0 deletions wild_visual_navigation/utils/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Any, List, Optional, Type, Union
from typing_extensions import Self
import torch


class Data:
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)


class Batch:
def __init__(self):
pass

@classmethod
def from_data_list(cls, list_of_data: List[Data]) -> Self:
if len(list_of_data) == 0:
return None

base = ["x"]

tensors_to_concatenate = [
k for k in dir(list_of_data[0]) if k[0] != "_" and getattr(list_of_data[0], k) is not None and not k in base
]
base = base + tensors_to_concatenate

for k in base:
if k == "edge_index":
ls = []
for j, data in enumerate(list_of_data):
ls.append(getattr(data, k) + cls.ptr[j])

cls.edge_index = torch.cat(ls, dim=-1)
else:

if k == "x":
running = 0
ptrs = [running]
batches = []

for j, data in enumerate(list_of_data):
running = running + getattr(data, k).shape[0]
ptrs.append(running)
batches += [j] * int(getattr(data, k).shape[0])

cls.ptr = torch.tensor(ptrs, dtype=torch.long)
cls.batch = torch.tensor(batches, dtype=torch.long)

setattr(cls, k, torch.cat([getattr(data, k) for data in list_of_data], dim=0))

cls.ba = cls.x.shape[0]
return cls


if __name__ == "__main__":
from torch_geometric.data import Data as DataTorchGeometric
from torch_geometric.data import Batch as BatchTorchGeometric

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data_tg1 = DataTorchGeometric(x=x, edge_index=edge_index)
data1 = Data(x=x, edge_index=edge_index)
edge_index2 = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
x2 = torch.tensor([[-1], [12], [1]], dtype=torch.float)
data_tg2 = DataTorchGeometric(x=x2, edge_index=edge_index2)
data2 = Data(x=x2, edge_index=edge_index2)
batch = BatchTorchGeometric.from_data_list([data_tg1, data_tg2])
3 changes: 2 additions & 1 deletion wild_visual_navigation/utils/loss.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from wild_visual_navigation.utils import ConfidenceGenerator

import torch.nn.functional as F
from torch_geometric.data import Data
from wild_visual_navigation.utils import Data

import torch
from typing import Optional
from torch import nn
Expand Down
6 changes: 3 additions & 3 deletions wild_visual_navigation_anymal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ set (CMAKE_CXX_STANDARD 14)
set(CATKIN_PACKAGE_LIST
rospy
roscpp
anymal_msgs
sensor_msgs
std_msgs
wild_visual_navigation_msgs
Expand All @@ -30,9 +29,10 @@ include_directories(
)

if(BUILD_ANYMAL)
find_package( anymal_msgs )
# Declare node
add_executable(anymal_msg_converter_cpp_node
src/anymal_msg_converter_cpp_node.cpp)
add_executable( anymal_msg_converter_cpp_node
src/anymal_msg_converter_cpp_node.cpp)

target_link_libraries(anymal_msg_converter_cpp_node
${catkin_LIBRARIES}
Expand Down
5 changes: 3 additions & 2 deletions wild_visual_navigation_anymal/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
Setup on ANYmal:


NPC:
```
sudo apt-get install ros-noetic-anymal-msgs-dev
Expand All @@ -11,4 +10,6 @@ catkin build wild_visual_navigation_anymal --cmake-args -DBUILD_ANYMAL=1

Jetson:
```
```
```

Currently the resizing is not working which is pretty bad for the wide angle camera.
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ subscribers:
data_type: pointcloud

# Semantics
front_wide_angle:
topic_name: /wide_angle_camera_front/image_color_rect
camera_info_topic_name: /wide_angle_camera_front/camera_info
data_type: image

rear_wide_angle:
topic_name: /wide_angle_camera_rear/image_color_rect
camera_info_topic_name: /wide_angle_camera_rear/camera_info
data_type: image
# front_wide_angle:
# topic_name: /wide_angle_camera_front/image_color_rect
# camera_info_topic_name: /wide_angle_camera_front/camera_info
# data_type: image

# rear_wide_angle:
# topic_name: /wide_angle_camera_rear/image_color_rect
# camera_info_topic_name: /wide_angle_camera_rear/camera_info
# data_type: image

# Traversability
# channels: ["visual_traversability"]
Expand Down
2 changes: 1 addition & 1 deletion wild_visual_navigation_anymal/config/procman/replay.pmd
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ group "1 - WVN" {
}

cmd "1.2 - overlay" {
exec = "roslaunch wild_visual_navigation_ros overlay_images.launch";
exec = "roslaunch wild_visual_navigation_anymal overlay_images.launch";
host = "localhost";
}
cmd "1.3 - elevation_mapping_cupy" {
Expand Down
15 changes: 15 additions & 0 deletions wild_visual_navigation_anymal/config/recording/dodo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ npc:
jetson:
elevation_mapping:
- /elevation_mapping/elevation_map_raw
- /elevation_mapping/semantic_map
generic:
- /chrony_monitor_jetson/status
- /cpu_loupe_jetson/cpu_loupe
Expand All @@ -140,3 +141,17 @@ jetson:
hdr_camera:
- /hdr_camera/image_raw/compressed
- /hdr_camera/camera_info
wvn:
- /wild_visual_navigation_node/front/camera_info
- /wild_visual_navigation_node/front/confidence
- /wild_visual_navigation_node/front/feat
- /wild_visual_navigation_node/front/image_input
- /wild_visual_navigation_node/front/traversability
- /wild_visual_navigation_node/graph_footprints
- /wild_visual_navigation_node/graph_footprints_array
- /wild_visual_navigation_node/rear/camera_info
- /wild_visual_navigation_node/rear/confidence
- /wild_visual_navigation_node/rear/image_input
- /wild_visual_navigation_node/rear/traversability
- /wild_visual_navigation_node/robot_state
- /wild_visual_navigation_node/supervision_graph
Original file line number Diff line number Diff line change
@@ -1 +1 @@
rosservice call /rosbag_record_coordinator/start_recording "yaml_file: '/home/jonfrey/git/wild_visual_navigation/wild_visual_navigation_ros/config/recording/dodo.yaml'"
rosservice call /rosbag_record_coordinator/start_recording "yaml_file: '/home/jonfrey/git/wild_visual_navigation/wild_visual_navigation_anymal/config/recording/dodo.yaml'"
Loading

0 comments on commit f06e079

Please sign in to comment.