Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend function features to allow new repo types #190

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions src/lean_dojo/data_extraction/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ def __post_init__(self):
lock_path = self.cache_dir.with_suffix(".lock")
object.__setattr__(self, "lock", FileLock(lock_path))

def get(self, url: str, commit: str) -> Optional[Path]:
def get(self, url: str, commit: str, prefix: str = "") -> Optional[Path]:
"""Get the path of a traced repo with URL ``url`` and commit hash ``commit``. Return None if no such repo can be found."""
_, repo_name = _split_git_url(url)
dirname = _format_dirname(url, commit)
dirpath = self.cache_dir / dirname
dirpath = self.cache_dir / prefix / dirname

with self.lock:
if dirpath.exists():
Expand All @@ -90,16 +90,20 @@ def get(self, url: str, commit: str) -> Optional[Path]:
else:
return None

def store(self, src: Path) -> Path:
"""Store a traced repo at path ``src``. Return its path in the cache."""
url, commit = get_repo_info(src)
dirpath = self.cache_dir / _format_dirname(url, commit)
_, repo_name = _split_git_url(url)
def store(self, src: Path, rel_cache_dir: Path) -> Path:
    """Store a traced repo at path ``src``. Return its path in the cache.

    Args:
        src (Path): Path to the repo.
        rel_cache_dir (Path): The relative path of the stored repo in the cache.

    Returns:
        Path: Absolute path of the stored repo inside the cache.
    """
    cache_path = self.cache_dir / rel_cache_dir
    # Check the destination itself (not just its parent): two repos may share
    # a parent directory, in which case a parent-only check would skip the
    # copy and return a path that does not exist. The check runs under the
    # lock so a concurrent store of the same repo cannot race the copy.
    with self.lock:
        if not cache_path.exists():
            with report_critical_failure(_CACHE_CORRPUTION_MSG):
                # copytree creates missing intermediate directories.
                shutil.copytree(src, cache_path)
    return cache_path


cache = Cache(CACHE_DIR)
Expand Down
144 changes: 115 additions & 29 deletions src/lean_dojo/data_extraction/lean.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,21 @@
from github.GithubException import GithubException
from typing import List, Dict, Any, Generator, Union, Optional, Tuple, Iterator
from git import Repo
from ..constants import TMP_DIR
import uuid
import shutil
from urllib.parse import urlparse


from ..utils import (
execute,
read_url,
url_exists,
get_repo_info,
working_directory,
is_git_repo,
)
from ..constants import LEAN4_URL
from .cache import _format_dirname


GITHUB_ACCESS_TOKEN = os.getenv("GITHUB_ACCESS_TOKEN", None)
Expand All @@ -45,24 +50,83 @@
)
GITHUB = Github()

LEAN4_REPO = GITHUB.get_repo("leanprover/lean4")
LEAN4_REPO = None
"""The GitHub Repo for Lean 4 itself."""

_URL_REGEX = re.compile(r"(?P<url>.*?)/*")

_SSH_TO_HTTPS_REGEX = re.compile(r"^git@github\.com:(.+)/(.+)(?:\.git)?$")

def normalize_url(url: str) -> str:

def normalize_url(url: str, repo_type: str = "github") -> str:
    """Normalize a repository URL.

    Non-local repos get trailing ``/`` characters stripped; local repos are
    identified by their absolute filesystem path.
    """
    if repo_type != "local":
        # `_URL_REGEX` captures everything before any trailing slashes.
        return _URL_REGEX.fullmatch(url)["url"]
    return os.path.abspath(url)


def repo_type_of_url(url: str) -> Union[str, None]:
    """Get the type of the repository.

    Args:
        url (str): The URL of the repository.
    Returns:
        str: The type of the repository ("github", "remote", or "local"),
        or None if the URL is not recognized.
    """
    # Rewrite SSH-style GitHub URLs to their HTTPS equivalent first.
    ssh_match = _SSH_TO_HTTPS_REGEX.match(url)
    if ssh_match:
        url = f"https://github.com/{ssh_match.group(1)}/{ssh_match.group(2)}"
    parsed_url = urlparse(url)
    if parsed_url.scheme in ["http", "https"]:
        # case 1 - GitHub URL
        if "github.com" in url:
            if url.startswith("https://"):
                return "github"
            logger.warning(f"{url} should start with https://")
            return None
        # case 2 - remote URL (existence only; not verified to be a git URL)
        if url_exists(url):
            return "remote"
    # case 3 - local path
    elif is_git_repo(Path(parsed_url.path)):
        return "local"
    logger.warning(f"{url} is not a valid URL")
    return None


@cache
def url_to_repo(url: str, num_retries: int = 2) -> Repository:
def url_to_repo(
url: str,
num_retries: int = 2,
repo_type: Union[str, None] = None,
tmp_dir: Union[Path] = None,
) -> Union[Repo, Repository]:
"""Convert a URL to a Repo object.

Args:
url (str): The URL of the repository.
num_retries (int): Number of retries in case of failure.
repo_type (Optional[str]): The type of the repository. Defaults to None.
tmp_dir (Optional[Path]): The temporary directory to clone the repo to. Defaults to None.

Returns:
Repo: A Git Repo object.
"""
url = normalize_url(url)
backoff = 1

tmp_dir = tmp_dir or os.path.join(TMP_DIR or "/tmp", str(uuid.uuid4())[:8])
repo_type = repo_type or repo_type_of_url(url)
while True:
try:
return GITHUB.get_repo("/".join(url.split("/")[-2:]))
if repo_type == "github":
return GITHUB.get_repo("/".join(url.split("/")[-2:]))
with working_directory(tmp_dir):
repo_name = os.path.basename(url)
if repo_type == "local":
assert is_git_repo(url), f"Local path {url} is not a git repo"
shutil.copytree(url, repo_name)
return Repo(repo_name)
else:
return Repo.clone_from(url, repo_name)
except Exception as ex:
if num_retries <= 0:
raise ex
Expand All @@ -76,29 +140,39 @@ def url_to_repo(url: str, num_retries: int = 2) -> Repository:
def get_latest_commit(url: str) -> str:
    """Get the hash of the latest commit of the Git repo at ``url``."""
    repo = url_to_repo(url)
    if isinstance(repo, Repository):
        # PyGithub object: query the head of the default branch via the API.
        return repo.get_branch(repo.default_branch).commit.sha
    # GitPython object: read HEAD of the local clone.
    return repo.head.commit.hexsha


def cleanse_string(s: Union[str, Path]) -> str:
    """Replace : and / with _ in a string."""
    # Single-pass character substitution instead of chained `replace` calls.
    return str(s).translate(str.maketrans({"/": "_", ":": "_"}))


@cache
def _to_commit_hash(repo: Union[Repository, Repo], label: str) -> str:
    """Convert a tag or branch to a commit hash.

    Args:
        repo (Union[Repository, Repo]): A PyGithub ``Repository`` or a
            GitPython ``Repo`` object.
        label (str): A branch name, tag, or other ref to resolve.

    Returns:
        str: The full commit hash ``label`` resolves to.

    Raises:
        ValueError: If ``label`` cannot be resolved to a commit.
        TypeError: If ``repo`` is neither a ``Repository`` nor a ``Repo``.
    """
    # GitHub repository: resolve through the REST API.
    if isinstance(repo, Repository):
        logger.debug(f"Querying the commit hash for {repo.name} {label}")
        try:
            commit = repo.get_commit(label).sha
        except GithubException as e:
            # Chain the original exception so API error details are preserved.
            raise ValueError(
                f"Invalid tag or branch: `{label}` for {repo.name}"
            ) from e
    # Local or remote Git repository: resolve with GitPython.
    elif isinstance(repo, Repo):
        logger.debug(
            f"Querying the commit hash for {repo.working_dir} repository {label}"
        )
        try:
            # Resolve the label to a commit hash
            commit = repo.commit(label).hexsha
        except Exception as e:
            raise ValueError(f"Error converting ref to commit hash: {e}") from e
    else:
        raise TypeError("Unsupported repository type")
    return commit


@dataclass(eq=True, unsafe_hash=True)
Expand Down Expand Up @@ -320,6 +394,11 @@ def __getitem__(self, key) -> str:
_LEAN4_VERSION_REGEX = re.compile(r"leanprover/lean4:(?P<version>.+?)")


def is_commit_hash(s: str) -> bool:
    """Check if a string is a valid (full, 40-character) commit hash."""
    # Return a real bool rather than `Match | False`; truthiness-compatible
    # with the previous behavior for all callers.
    return len(s) == 40 and _COMMIT_REGEX.fullmatch(s) is not None


def get_lean4_version_from_config(toolchain: str) -> str:
"""Return the required Lean version given a ``lean-toolchain`` config."""
m = _LEAN4_VERSION_REGEX.fullmatch(toolchain.strip())
Expand All @@ -330,6 +409,9 @@ def get_lean4_version_from_config(toolchain: str) -> str:
def get_lean4_commit_from_config(config_dict: Dict[str, Any]) -> str:
"""Return the required Lean commit given a ``lean-toolchain`` config."""
assert "content" in config_dict, "config_dict must have a 'content' field"
global LEAN4_REPO
if LEAN4_REPO is None:
LEAN4_REPO = GITHUB.get_repo("leanprover/lean4")
config = config_dict["content"].strip()
prefix = "leanprover/lean4:"
assert config.startswith(prefix), f"Invalid Lean 4 version: {config}"
Expand Down Expand Up @@ -416,12 +498,12 @@ def __post_init__(self) -> None:
object.__setattr__(self, "repo", url_to_repo(self.url))

# Convert tags or branches to commit hashes
if not (len(self.commit) == 40 and _COMMIT_REGEX.fullmatch(self.commit)):
if not is_commit_hash(self.commit):
if (self.url, self.commit) in info_cache.tag2commit:
commit = info_cache.tag2commit[(self.url, self.commit)]
else:
commit = _to_commit_hash(self.repo, self.commit)
assert _COMMIT_REGEX.fullmatch(commit), f"Invalid commit hash: {commit}"
assert is_commit_hash(commit)
info_cache.tag2commit[(self.url, self.commit)] = commit
object.__setattr__(self, "commit", commit)

Expand All @@ -432,24 +514,23 @@ def __post_init__(self) -> None:
lean_version = self.commit
else:
config = self.get_config("lean-toolchain")
lean_version = get_lean4_commit_from_config(config)
v = get_lean4_version_from_config(config["content"])
if not is_supported_version(v):
lean_version = get_lean4_version_from_config(config["content"])
if not is_supported_version(lean_version):
logger.warning(
f"{self} relies on an unsupported Lean version: {lean_version}"
)
info_cache.lean_version[(self.url, self.commit)] = lean_version
object.__setattr__(self, "lean_version", lean_version)

@classmethod
def from_path(cls, path: Union[Path, str]) -> "LeanGitRepo":
    """Construct a :class:`LeanGitRepo` object from the path to a local Git repo."""
    repo_url, repo_commit = get_repo_info(Path(path))
    return cls(repo_url, repo_commit)

@property
def name(self) -> str:
    """The repository's name: the final component of its URL (or local path)."""
    _, tail = os.path.split(self.url)
    return tail

@property
def is_lean4(self) -> bool:
Expand All @@ -459,6 +540,11 @@ def is_lean4(self) -> bool:
def commit_url(self) -> str:
return f"{self.url}/tree/{self.commit}"

@property
def format_dirname(self) -> str:
    """The repo's directory name inside the traced-repo cache, derived from its URL and commit."""
    return _format_dirname(self.url, self.commit)

def show(self) -> None:
    """Show the repo in the default browser."""
    # Opens the commit-pinned tree view URL (`{url}/tree/{commit}`).
    webbrowser.open(self.commit_url)
Expand Down
4 changes: 3 additions & 1 deletion src/lean_dojo/data_extraction/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,9 @@ def get_traced_repo_path(repo: LeanGitRepo, build_deps: bool = True) -> Path:
_trace(repo, build_deps)
traced_repo = TracedRepo.from_traced_files(tmp_dir / repo.name, build_deps)
traced_repo.save_to_disk()
path = cache.store(tmp_dir / repo.name)
src_dir = tmp_dir / repo.name
rel_cache_dir = Path(repo.format_dirname) / repo.name
path = cache.store(src_dir, rel_cache_dir)
else:
logger.debug("The traced repo is available in the cache.")
return path
Expand Down
8 changes: 6 additions & 2 deletions src/lean_dojo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,13 @@ def read_url(url: str, num_retries: int = 2) -> str:

@cache
def url_exists(url: str) -> bool:
    """Return True if the URL ``url`` exists, using the GITHUB_ACCESS_TOKEN for authentication if provided."""
    try:
        request = urllib.request.Request(url)
        gh_token = os.getenv("GITHUB_ACCESS_TOKEN")
        if gh_token is not None:
            # Authenticated GitHub requests get a much higher rate limit.
            request.add_header("Authorization", f"token {gh_token}")
        with urllib.request.urlopen(request) as _:
            return True
    except urllib.error.URLError:
        # HTTPError (e.g. 404) is a subclass of URLError. Catching URLError
        # also covers DNS failures / unreachable hosts / invalid file URLs,
        # which previously escaped and crashed callers expecting a bool.
        return False
Expand Down
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
MATHLIB4_URL = "https://github.com/leanprover-community/mathlib4"
LEAN4_EXAMPLE_URL = "https://github.com/yangky11/lean4-example"
EXAMPLE_COMMIT_HASH = "3f8c5eb303a225cdef609498b8d87262e5ef344b"
REMOTE_EXAMPLE_URL = "https://gitee.com/rexzong/lean4-example"
URLS = [
BATTERIES_URL,
AESOP_URL,
Expand All @@ -16,6 +17,11 @@
]


@pytest.fixture(scope="session")
def remote_example_url():
    # A git repo hosted outside GitHub (Gitee mirror of lean4-example),
    # used to exercise the non-GitHub "remote" repo type.
    return REMOTE_EXAMPLE_URL


@pytest.fixture(scope="session")
def example_commit_hash():
    # A pinned commit of the lean4-example repo so tests are reproducible.
    return EXAMPLE_COMMIT_HASH
Expand Down
55 changes: 55 additions & 0 deletions tests/data_extraction/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# test for cache manager
from git import Repo
from lean_dojo.utils import working_directory
from pathlib import Path
from lean_dojo.data_extraction.lean import _format_dirname
from lean_dojo.data_extraction.cache import cache


def test_get_cache(lean4_example_url, remote_example_url, example_commit_hash):
    """Round-trip `cache.store` / `cache.get` for both local-path and remote-URL keys."""
    # Note: The `git.Repo` requires the local repo to be cloned in a directory
    # all cached repos are stored in CACHE_DIR/repos
    prefix = "repos"

    # test local repo cache
    with working_directory() as tmp_dir:
        # assume that the local repo placed in `/.../testrepo/lean4-example`
        repo = Repo.clone_from(lean4_example_url, "testrepo/lean4-example")
        repo.git.checkout(example_commit_hash)
        local_dir = tmp_dir / "testrepo/lean4-example"
        # use local_dir as the key to store the cache
        rel_cache_dir = (
            prefix
            / Path(_format_dirname(str(local_dir), example_commit_hash))
            / local_dir.name
        )
        cache.store(local_dir, rel_cache_dir)
        # get the cache
        local_url, local_commit = str(local_dir), example_commit_hash
        repo_cache = cache.get(local_url, local_commit, prefix)
        # For a local key the dirname is built from the last two path
        # components plus the commit hash.
        assert (
            _format_dirname(local_url, local_commit)
            == f"{local_dir.parent.name}-{local_dir.name}-{local_commit}"
        )
        assert repo_cache is not None

    # test remote repo cache
    with working_directory() as tmp_dir:
        repo = Repo.clone_from(remote_example_url, "lean4-example")
        repo.git.checkout(example_commit_hash)
        tmp_remote_dir = tmp_dir / "lean4-example"
        # use remote url as the key to store the cache
        rel_cache_dir = (
            prefix
            / Path(_format_dirname(str(remote_example_url), example_commit_hash))
            / tmp_remote_dir.name
        )
        cache.store(tmp_remote_dir, rel_cache_dir)
        # get the cache
        remote_url, remote_commit = remote_example_url, example_commit_hash
        repo_cache = cache.get(remote_url, remote_commit, prefix)
        assert repo_cache is not None
        # For a URL key the dirname is `{owner}-{repo}-{commit}`.
        assert (
            _format_dirname(remote_url, remote_commit)
            == f"rexzong-lean4-example-{example_commit_hash}"
        )
Loading