diff --git a/.gitignore b/.gitignore index cd92677..8d89e22 100644 --- a/.gitignore +++ b/.gitignore @@ -130,3 +130,6 @@ dmypy.json # Pyre type checker .pyre/ + +# vscode debug config +.vscode/ \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 4625fc2..d6c7061 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "python-dotenv", "loguru", "filelock", + "gitpython", "psutil", "pexpect", "types-psutil", diff --git a/src/lean_dojo/constants.py b/src/lean_dojo/constants.py index aa50261..fd3f74b 100644 --- a/src/lean_dojo/constants.py +++ b/src/lean_dojo/constants.py @@ -71,20 +71,28 @@ assert re.fullmatch(r"\d+g", TACTIC_MEMORY_LIMIT) -def check_git_version(min_version: Tuple[int, int, int]) -> Tuple[int, int, int]: +def check_git_version(min_version: Tuple[int, int, int]) -> None: """Check the version of Git installed on the system.""" - res = subprocess.run("git --version", shell=True, capture_output=True, check=True) - output = res.stdout.decode() - error = res.stderr.decode() - assert error == "", error - m = re.match(r"git version (?P<version>[0-9.]+)", output) - version = tuple(int(_) for _ in m["version"].split(".")) - - version_str = ".".join(str(_) for _ in version) - min_version_str = ".".join(str(_) for _ in min_version) - assert ( - version >= min_version - ), f"Git version {version_str} is too old. Please upgrade to at least {min_version_str}." 
+ try: + res = subprocess.run( + "git --version", shell=True, capture_output=True, check=True + ) + output = res.stdout.decode().strip() + error = res.stderr.decode() + assert error == "", error + match = re.search(r"git version (\d+\.\d+\.\d+)", output) + if not match: + raise ValueError("Could not parse Git version from the output.") + # Convert version number string to tuple of integers + version = tuple(int(_) for _ in match.group(1).split(".")) + + version_str = ".".join(str(_) for _ in version) + min_version_str = ".".join(str(_) for _ in min_version) + assert ( + version >= min_version + ), f"Git version {version_str} is too old. Please upgrade to at least {min_version_str}." + except subprocess.CalledProcessError as e: + raise Exception(f"Failed to run git command: {e}") check_git_version((2, 25, 0)) diff --git a/src/lean_dojo/data_extraction/lean.py b/src/lean_dojo/data_extraction/lean.py index 427b531..0daed9e 100644 --- a/src/lean_dojo/data_extraction/lean.py +++ b/src/lean_dojo/data_extraction/lean.py @@ -17,6 +17,8 @@ from github.Repository import Repository from github.GithubException import GithubException from typing import List, Dict, Any, Generator, Union, Optional, Tuple, Iterator +from git import Repo + from ..utils import ( execute, @@ -457,7 +459,7 @@ def is_lean4(self) -> bool: @property def commit_url(self) -> str: - return os.path.join(self.url, f"tree/{self.commit}") + return f"{self.url}/tree/{self.commit}" def show(self) -> None: """Show the repo in the default browser.""" @@ -469,12 +471,9 @@ def exists(self) -> bool: def clone_and_checkout(self) -> None: """Clone the repo to the current working directory and checkout a specific commit.""" logger.debug(f"Cloning {self}") - execute(f"git clone -n --recursive {self.url}", capture_output=True) - with working_directory(self.name): - execute( - f"git checkout {self.commit} && git submodule update --recursive", - capture_output=True, - ) + repo = Repo.clone_from(self.url, Path(self.name), 
no_checkout=True) + repo.git.checkout(self.commit) + repo.submodule_update(init=True, recursive=True) def get_dependencies( self, path: Union[str, Path, None] = None diff --git a/tests/conftest.py b/tests/conftest.py index 1e14c96..7113916 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ AESOP_URL = "https://github.com/leanprover-community/aesop" MATHLIB4_URL = "https://github.com/leanprover-community/mathlib4" LEAN4_EXAMPLE_URL = "https://github.com/yangky11/lean4-example" +EXAMPLE_COMMIT_HASH = "3f8c5eb303a225cdef609498b8d87262e5ef344b" URLS = [ BATTERIES_URL, AESOP_URL, @@ -15,6 +16,16 @@ ] +@pytest.fixture(scope="session") +def example_commit_hash(): + return EXAMPLE_COMMIT_HASH + + +@pytest.fixture(scope="session") +def lean4_example_url(): + return LEAN4_EXAMPLE_URL + + @pytest.fixture(scope="session") def monkeysession(): with pytest.MonkeyPatch.context() as mp: diff --git a/tests/data_extraction/test_lean_repo.py b/tests/data_extraction/test_lean_repo.py new file mode 100644 index 0000000..005218c --- /dev/null +++ b/tests/data_extraction/test_lean_repo.py @@ -0,0 +1,11 @@ +# test for the class `LeanGitRepo` +from lean_dojo import LeanGitRepo + + +def test_lean_git_repo(lean4_example_url, example_commit_hash): + repo = LeanGitRepo(lean4_example_url, example_commit_hash) + assert repo.url == lean4_example_url + assert repo.commit == example_commit_hash + assert repo.exists() + assert repo.name == "lean4-example" + assert repo.commit_url == f"{lean4_example_url}/tree/{example_commit_hash}" diff --git a/tests/interaction/test_interaction.py b/tests/interaction/test_interaction.py new file mode 100644 index 0000000..1e0b858 --- /dev/null +++ b/tests/interaction/test_interaction.py @@ -0,0 +1,17 @@ +from lean_dojo import LeanGitRepo, Dojo, ProofFinished, ProofGivenUp, Theorem + + +def test_remote_interact(lean4_example_url): + repo = LeanGitRepo(url=lean4_example_url, commit="main") + theorem = Theorem(repo, "Lean4Example.lean", 
"hello_world") + # initial state + dojo, state_0 = Dojo(theorem).__enter__() + assert state_0.pp == "a b c : Nat\n⊢ a + b + c = a + c + b" + # state after running a tactic + state_1 = dojo.run_tac(state_0, "rw [add_assoc]") + assert state_1.pp == "a b c : Nat\n⊢ a + (b + c) = a + c + b" + # state after running another a sorry tactic + assert dojo.run_tac(state_1, "sorry") == ProofGivenUp() + # finish proof + final_state = dojo.run_tac(state_1, "rw [add_comm b, ←add_assoc]") + assert isinstance(final_state, ProofFinished)