Skip to content

Commit

Permalink
fix cache directory mount and argument
Browse files Browse the repository at this point in the history
  • Loading branch information
ojh6404 committed Apr 14, 2024
1 parent e60fd10 commit b9e6243
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions tracking_ros/run_docker
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
import argparse
import os
import re
import shutil
import subprocess
Expand All @@ -9,16 +10,19 @@ from typing import Optional

_TRACKING_ROS_ROOT_INSIDE_CONTAINER = "/home/user/tracking_ws/src/tracking_ros"
_TORCH_CACHE_DIR_INSIDE_CONTAINER = "/home/user/.cache/torch"
_TORCH_CACHE_DIR = Path.home() / ".cache" / "torch"
_TORCH_CACHE_DIR = os.getenv("TORCH_HOME", Path.home() / ".cache" / "torch")
_HF_CACHE_DIR_INSIDE_CONTAINER = "/home/user/.cache/huggingface"
_HF_CACHE_DIR = Path.home() / ".cache" / "huggingface"
_HF_CACHE_DIR = os.getenv("HF_HOME", Path.home() / ".cache" / "huggingface")

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-launch", type=str, help="launch file name")
parser.add_argument(
"-host", type=str, default="pr1040", help="host name or ip-address"
)
parser.add_argument(
"-cache", default=False, action="store_true", help="mount cache directory or not"
)
parser.add_argument(
"launch_args",
nargs=argparse.REMAINDER,
Expand Down Expand Up @@ -51,12 +55,20 @@ if __name__ == "__main__":
shutil.copytree(mount_path, tmp_launch_path)
else:
shutil.copyfile(mount_path, tmp_launch_path)

cache_option = """-v {torch_cache_dir}:{torch_cache_dir_inside_container} \
-v {hf_cache_dir}:{hf_cache_dir_inside_container} \
""".format(
torch_cache_dir=_TORCH_CACHE_DIR,
torch_cache_dir_inside_container=_TORCH_CACHE_DIR_INSIDE_CONTAINER,
hf_cache_dir=_HF_CACHE_DIR,
hf_cache_dir_inside_container=_HF_CACHE_DIR_INSIDE_CONTAINER,
) if args.cache else ""
docker_run_command = """
docker run \
-v {node_scripts_dir}:{tracking_ros_root}/node_scripts \
-v {tmp_launch_path}:{tracking_ros_root}/launch \
-v {torch_cache_dir}:{torch_cache_dir_inside_container} \
-v {hf_cache_dir}:{hf_cache_dir_inside_container} \
{cache_option} \
--rm --net=host -it \
{gpu_arg} \
tracking_ros:latest \
Expand All @@ -69,10 +81,7 @@ if __name__ == "__main__":
node_scripts_dir=Path(__file__).resolve().parent / "node_scripts",
tmp_launch_path=tmp_launch_path,
tracking_ros_root=_TRACKING_ROS_ROOT_INSIDE_CONTAINER,
torch_cache_dir=_TORCH_CACHE_DIR,
torch_cache_dir_inside_container=_TORCH_CACHE_DIR_INSIDE_CONTAINER,
hf_cache_dir=_HF_CACHE_DIR,
hf_cache_dir_inside_container=_HF_CACHE_DIR_INSIDE_CONTAINER,
cache_option=cache_option,
gpu_arg="--gpus 1" if use_gpu else "",
host=args.host,
launch_file_name=launch_file_name,
Expand Down

0 comments on commit b9e6243

Please sign in to comment.