Skip to content

Commit

Permalink
Minor changes in postprocess logger
Browse files Browse the repository at this point in the history
  • Loading branch information
mmattamala committed Feb 8, 2024
1 parent a9518b7 commit 9e3b3ae
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# for offline analysis, such as images and learning curves

from sensor_msgs.msg import Image
from rosgraph_msgs.msg import Clock
from wild_visual_navigation_msgs.msg import SystemState
from cv_bridge import CvBridge

Expand All @@ -13,7 +12,6 @@
import os
import rospy
import rospkg
import time
import yaml


Expand All @@ -37,14 +35,29 @@ def __init__(self):
# Initialize variables
self._bridge = CvBridge()

package_path = rospkg.RosPack().get_path("wild_visual_navigation_ros")
base_output_path = os.path.join(package_path, "output")

# Remove old checkpoint
state_dict = os.path.join(package_path, "../.tmp_state_dict.pt")
if os.path.exists(state_dict):
os.remove(state_dict)

# Initialize log folder
rospy.wait_for_message("/clock", Clock)
stamp = time.localtime(rospy.get_time())
mission_name = f"{self._mission_name}_{time.strftime('%Y%m%d_%H%M%S', stamp)}"
run = 0
while True:
mission_name = f"{self._mission_name}_{str(run).zfill(2)}"
self._output_path = os.path.join(base_output_path, mission_name)

package_path = rospkg.RosPack().get_path("wild_visual_navigation_ros")
self._output_path = os.path.join(package_path, "output", mission_name)
os.makedirs(self._output_path, exist_ok=True)
if os.path.exists(self._output_path):
run += 1
else:
break

# Make folder
mission_name = f"{self._mission_name}_{str(run).zfill(2)}"
self._output_path = os.path.join(base_output_path, mission_name)
os.makedirs(self._output_path, exist_ok=False)

# Initialize CSV writer
csv_file = open(f"{self._output_path}/wvn_state.csv", "w")
Expand Down Expand Up @@ -104,7 +117,7 @@ def _state_callback(self, system_state_msg):
stamp = rospy.get_rostime()
self._csv_writer.writerow(
[
f"{stamp.secs}.{stamp.nsecs}",
f"{secs_to_str(stamp.secs)}.{nsecs_to_str(stamp.nsecs)}",
system_state_msg.mode,
system_state_msg.mission_graph_num_valid_node,
system_state_msg.loss_total,
Expand Down
5 changes: 4 additions & 1 deletion wild_visual_navigation_ros/scripts/wvn_learning_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,10 @@ def learning_thread_loop(self):
# Check the rate
ts = rospy.get_time()
if abs(ts - self._last_checkpoint_ts) > 1.0 / self._ros_params.load_save_checkpoint_rate:
os.remove(f"{WVN_ROOT_DIR}/.tmp_state_dict.pt")
try:
os.remove(f"{WVN_ROOT_DIR}/.tmp_state_dict.pt")
except Exception:
pass
torch.save(new_model_state_dict, f"{WVN_ROOT_DIR}/.tmp_state_dict.pt")
self._last_checkpoint_ts = ts

Expand Down

0 comments on commit 9e3b3ae

Please sign in to comment.