Skip to content

Commit

Permalink
update cache method
Browse files Browse the repository at this point in the history
  • Loading branch information
RexWzh committed Jul 25, 2024
1 parent 570e787 commit 99af299
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
6 changes: 4 additions & 2 deletions src/lean_dojo/data_extraction/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,12 @@ def get(self, url: str, commit: str) -> Optional[Path]:
else:
return None

def store(self, src: Path) -> Path:
def store(self, src: Path, fmt_name:str='') -> Path:
"""Store a traced repo at path ``src`` in the cache directory named ``fmt_name`` (derived from the repo's URL and commit when empty). Return its path in the cache."""
url, commit = get_repo_info(src)
dirpath = self.cache_dir / _format_dirname(url, commit)
if fmt_name == '': # if not specified, extract from the traced repo
fmt_name = _format_dirname(url, commit)
dirpath = self.cache_dir / fmt_name
_, repo_name = _split_git_url(url)
if not dirpath.exists():
with self.lock:
Expand Down
19 changes: 10 additions & 9 deletions src/lean_dojo/data_extraction/lean.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
working_directory,
)
from ..constants import LEAN4_URL
from .cache import _split_git_url


GITHUB_ACCESS_TOKEN = os.getenv("GITHUB_ACCESS_TOKEN", None)
Expand Down Expand Up @@ -429,13 +430,7 @@ def __post_init__(self) -> None:
lean_version = self.commit
else:
config = self.get_config("lean-toolchain")
toolchain = config["content"]
m = _LEAN4_VERSION_REGEX.fullmatch(toolchain.strip())
if m is not None:
lean_version = m["version"]
else:
# lean_version_commit = get_lean4_commit_from_config(config)
lean_version = get_lean4_version_from_config(toolchain)
lean_version = get_lean4_version_from_config(config["content"])
if not is_supported_version(lean_version):
logger.warning(
f"{self} relies on an unsupported Lean version: {lean_version}"
Expand All @@ -444,9 +439,9 @@ def __post_init__(self) -> None:
object.__setattr__(self, "lean_version", lean_version)

@classmethod
def from_path(cls, path: Path) -> "LeanGitRepo":
def from_path(cls, path: Union[Path, str]) -> "LeanGitRepo":
"""Construct a :class:`LeanGitRepo` object from the path to a local Git repo."""
url, commit = get_repo_info(path)
url, commit = get_repo_info(Path(path))
return cls(url, commit)

@property
Expand All @@ -460,6 +455,12 @@ def is_lean4(self) -> bool:
@property
def commit_url(self) -> str:
    """Return the URL of this repo's tree at its commit: ``<url>/tree/<commit>``."""
    tree_url = f"{self.url}/tree/{self.commit}"
    return tree_url

@property
def format_dirname(self) -> str:
    """Return the formatted cache directory name.

    The name joins, with ``-``, the two components that ``_split_git_url``
    returns for ``self.url`` (bound here as owner and project) followed by
    ``self.commit``.
    """
    owner, project = _split_git_url(self.url)
    return f"{owner}-{project}-{self.commit}"

def show(self) -> None:
"""Show the repo in the default browser."""
Expand Down
2 changes: 1 addition & 1 deletion src/lean_dojo/data_extraction/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def get_traced_repo_path(repo: LeanGitRepo, build_deps: bool = True) -> Path:
_trace(repo, build_deps)
traced_repo = TracedRepo.from_traced_files(tmp_dir / repo.name, build_deps)
traced_repo.save_to_disk()
path = cache.store(tmp_dir / repo.name)
path = cache.store(tmp_dir / repo.name, repo.format_dirname)
else:
logger.debug("The traced repo is available in the cache.")
return path
Expand Down

0 comments on commit 99af299

Please sign in to comment.