From 172a885dcc00bf719e799976639cc7bda2f90be7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filiph=20Siitam=20Sandstr=C3=B6m?= Date: Wed, 13 Dec 2023 20:36:35 +0000 Subject: [PATCH] Add support for `sweep.yml` as an alternative to the default `sweep.yaml` This is the preferred file extension outside of the python ecosystem. Preferably it would be entierly customizable but that would be a much larger refector and out of scope for this commit. --- docs/components/PRPreview.jsx | 2 +- docs/pages/usage/advanced.mdx | 2 +- docs/pages/usage/config.mdx | 2 +- docs/pages/usage/sandbox.mdx | 2 +- sandbox/cli.py | 5 ++++ sandbox/src/sandbox_utils.py | 10 ++++++-- sweepai/api.py | 2 ++ sweepai/config/client.py | 40 ++++++++++++++++++++++++++--- sweepai/handlers/create_pr.py | 9 +++++-- sweepai/handlers/on_ticket.py | 2 +- tests/archive/test_diff_parsing3.py | 4 +-- tests/archive/test_match.py | 4 +-- tests/archive/test_naive_chunker.py | 4 +-- tests/search/test_lexical_search.py | 4 +-- 14 files changed, 72 insertions(+), 20 deletions(-) diff --git a/docs/components/PRPreview.jsx b/docs/components/PRPreview.jsx index b6411c288f..fc28096f2f 100644 --- a/docs/components/PRPreview.jsx +++ b/docs/components/PRPreview.jsx @@ -163,7 +163,7 @@ export function PRPreview({ repoName, prId }) { }} > {parsedDiff.map(({chunks, from, oldStart}) => ( - from !== "/dev/null" && from !== "sweep.yaml" && + from !== "/dev/null" && (from !== "sweep.yaml" || from !== "sweep.yml") && <>

Template `sweep.yaml` to copy diff --git a/sandbox/cli.py b/sandbox/cli.py index c363534222..3e4b1ffc07 100644 --- a/sandbox/cli.py +++ b/sandbox/cli.py @@ -51,6 +51,8 @@ def from_yaml(cls, yaml_string: str): def from_config(cls, path: str = "sweep.yaml"): if os.path.exists(path): return cls.from_yaml(open(path).read()) + elif os.path.exists(path.replace(".yaml", ".yml")): + return cls.from_yaml(open(path.replace(".yaml", ".yml")).read()) else: return cls() @@ -101,6 +103,9 @@ def get_sandbox_from_config(): if os.path.exists("sweep.yaml"): config = yaml.load(open("sweep.yaml", "r"), Loader=yaml.FullLoader) return Sandbox(**config.get("sandbox", {})) + elif os.path.exists("sweep.yml"): + config = yaml.load(open("sweep.yml", "r"), Loader=yaml.FullLoader) + return Sandbox(**config.get("sandbox", {})) else: return Sandbox() diff --git a/sandbox/src/sandbox_utils.py b/sandbox/src/sandbox_utils.py index ced989c62f..d62912c8ba 100644 --- a/sandbox/src/sandbox_utils.py +++ b/sandbox/src/sandbox_utils.py @@ -90,13 +90,19 @@ def from_yaml(cls, yaml_string: str): def from_config(cls, path: str = "sweep.yaml"): if os.path.exists(path): return cls.from_yaml(open(path).read()) + elif os.path.exists(path.replace(".yaml", ".yml")): + return cls.from_yaml(open(path.replace(".yaml", ".yml")).read()) else: return cls() @classmethod def from_directory(cls, path: str): - if os.path.exists(os.path.join(path, "sweep.yaml")): - sandbox = cls.from_yaml(open(os.path.join(path, "sweep.yaml")).read()) + if os.path.exists(os.path.join(path, "sweep.yaml")) or os.path.exists(os.path.join(path, "sweep.yml")): + if os.path.exists(os.path.join(path, "sweep.yaml")): + sandbox = cls.from_yaml(open(os.path.join(path, "sweep.yaml")).read()) + else: + sandbox = cls.from_yaml(open(os.path.join(path, "sweep.yml")).read()) + is_default_sandbox = True if sandbox.install != ["trunk init"]: is_default_sandbox = False diff --git a/sweepai/api.py b/sweepai/api.py index ea9d2a3749..d56da5fd3a 100644 --- a/sweepai/api.py +++ b/sweepai/api.py @@ -769,7 +769,9 @@ def remove_buttons_from_description(body): ): if request_dict["head_commit"] and ( "sweep.yaml" in request_dict["head_commit"]["added"] + or "sweep.yml" in request_dict["head_commit"]["added"] or "sweep.yaml" in request_dict["head_commit"]["modified"] + or "sweep.yml" in request_dict["head_commit"]["modified"] ): _, g = get_github_client(request_dict["installation"]["id"]) repo = g.get_repo(request_dict["repository"]["full_name"]) diff --git a/sweepai/config/client.py b/sweepai/config/client.py index 069aa43f26..9fbff4797c 100644 --- a/sweepai/config/client.py +++ b/sweepai/config/client.py @@ -90,6 +90,7 @@ class SweepConfig(BaseModel): ".pem", ".ttf", "sweep.yaml", + "sweep.yml" ] # Image formats max_file_limit: int = 60_000 @@ -122,6 +123,9 @@ def get_branch(repo: Repository, override_branch: str | None = None) -> str: sweep_yaml_dict = {} try: contents = repo.get_contents("sweep.yaml") + if contents is None: + contents = repo.get_contents("sweep.yml") + sweep_yaml_dict = yaml.safe_load( contents.decoded_content.decode("utf-8") ) @@ -153,6 +157,9 @@ def get_branch(repo: Repository, override_branch: str | None = None) -> str: def get_config(repo: Repository): try: contents = repo.get_contents("sweep.yaml") + if contents is None: + contents = repo.get_contents("sweep.yml") + config = yaml.safe_load(contents.decoded_content.decode("utf-8")) return SweepConfig(**config) except SystemExit: @@ -167,6 +174,9 @@ def get_config(repo: Repository): def get_draft(repo: Repository): try: contents = repo.get_contents("sweep.yaml") + if contents is None: + contents = repo.get_contents("sweep.yml") + config = yaml.safe_load(contents.decoded_content.decode("utf-8")) return config.get("draft", False) except SystemExit: @@ -180,6 +190,9 @@ def get_draft(repo: Repository): def get_gha_enabled(repo: Repository) -> bool: try: contents = repo.get_contents("sweep.yaml") + if contents is None: + contents = repo.get_contents("sweep.yml") + gha_enabled = yaml.safe_load(contents.decoded_content.decode("utf-8")).get( "gha_enabled", True ) @@ -197,6 +210,9 @@ def get_gha_enabled(repo: Repository) -> bool: def get_description(repo: Repository) -> dict: try: contents = repo.get_contents("sweep.yaml") + if contents is None: + contents = repo.get_contents("sweep.yml") + sweep_yaml = yaml.safe_load(contents.decoded_content.decode("utf-8")) description = sweep_yaml.get("description", "") rules = sweep_yaml.get("rules", []) @@ -212,6 +228,9 @@ def get_description(repo: Repository) -> dict: def get_sandbox_config(repo: Repository): try: contents = repo.get_contents("sweep.yaml") + if contents is None: + contents = repo.get_contents("sweep.yml") + description = yaml.safe_load(contents.decoded_content.decode("utf-8")).get( "sandbox", {} ) @@ -226,6 +245,9 @@ def get_sandbox_config(repo: Repository): def get_branch_name_config(repo: Repository): try: contents = repo.get_contents("sweep.yaml") + if contents is None: + contents = repo.get_contents("sweep.yml") + description = yaml.safe_load(contents.decoded_content.decode("utf-8")).get( "branch_use_underscores", False ) @@ -239,7 +261,11 @@ def get_branch_name_config(repo: Repository): @lru_cache(maxsize=None) def get_documentation_dict(repo: Repository): try: - sweep_yaml_content = repo.get_contents("sweep.yaml").decoded_content.decode( + contents = repo.get_contents("sweep.yaml") + if contents is None: + contents = repo.get_contents("sweep.yml") + + sweep_yaml_content = contents.decoded_content.decode( "utf-8" ) sweep_yaml = yaml.safe_load(sweep_yaml_content) @@ -254,7 +280,11 @@ def get_documentation_dict(repo: Repository): @lru_cache(maxsize=None) def get_blocked_dirs(repo: Repository): try: - sweep_yaml_content = repo.get_contents("sweep.yaml").decoded_content.decode( + contents = repo.get_contents("sweep.yaml") + if contents is None: + contents = repo.get_contents("sweep.yml") + + sweep_yaml_content = contents.decoded_content.decode( "utf-8" ) sweep_yaml = yaml.safe_load(sweep_yaml_content) @@ -269,7 +299,11 @@ def get_blocked_dirs(repo: Repository): @lru_cache(maxsize=None) def get_rules(repo: Repository): try: - sweep_yaml_content = repo.get_contents("sweep.yaml").decoded_content.decode( + contents = repo.get_contents("sweep.yaml") + if contents is None: + contents = repo.get_contents("sweep.yml") + + sweep_yaml_content = contents.decoded_content.decode( "utf-8" ) sweep_yaml = yaml.safe_load(sweep_yaml_content) diff --git a/sweepai/handlers/create_pr.py b/sweepai/handlers/create_pr.py index fb3a200f8b..5e06a493e8 100644 --- a/sweepai/handlers/create_pr.py +++ b/sweepai/handlers/create_pr.py @@ -141,7 +141,7 @@ def create_pr_changes( else: pr_description = f"{pull_request.content}" pr_title = pull_request.title - if "sweep.yaml" in pr_title: + if "sweep.yaml" in pr_title or "sweep.yml" in pr_title: pr_title = "[config] " + pr_title except MaxTokensExceeded as e: logger.error(e) @@ -237,7 +237,12 @@ def create_config_pr( except SystemExit: raise SystemExit except Exception: - pass + try: + repo.get_contents("sweep.yml") + except SystemExit: + raise SystemExit + except Exception: + pass title = "Configure Sweep" branch_name = GITHUB_CONFIG_BRANCH diff --git a/sweepai/handlers/on_ticket.py b/sweepai/handlers/on_ticket.py index a58dec19d2..fd028eb390 100644 --- a/sweepai/handlers/on_ticket.py +++ b/sweepai/handlers/on_ticket.py @@ -814,7 +814,7 @@ def edit_sweep_comment(message: str, index: int, pr_message="", done=False): sweep_yml_exists = False sweep_yml_failed = False for content_file in repo.get_contents(""): - if content_file.name == "sweep.yaml": + if content_file.name == "sweep.yaml" or content_file.name == "sweep.yml": sweep_yml_exists = True # Check if YAML is valid diff --git a/tests/archive/test_diff_parsing3.py b/tests/archive/test_diff_parsing3.py index 13269759fd..7d153a0add 100644 --- a/tests/archive/test_diff_parsing3.py +++ b/tests/archive/test_diff_parsing3.py @@ -670,10 +670,10 @@ def log_error(error_type, exception, priority=0): sweep_context=sweep_context, ) - # Check repository for sweep.yml file. + # Check repository for sweep.yaml/sweep.yml file. sweep_yml_exists = False for content_file in repo.get_contents(""): - if content_file.name == "sweep.yaml": + if content_file.name == "sweep.yaml" or content_file.name == "sweep.yml": sweep_yml_exists = True break diff --git a/tests/archive/test_match.py b/tests/archive/test_match.py index be3feaeaf6..9830e3aa5e 100644 --- a/tests/archive/test_match.py +++ b/tests/archive/test_match.py @@ -587,10 +587,10 @@ def edit_sweep_comment(message: str, index: int, pr_message="", done=False): cloned_repo=cloned_repo, ) - # Check repository for sweep.yml file. + # Check repository for sweep.yaml/sweep.yml file. sweep_yml_exists = False for content_file in repo.get_contents(""): - if content_file.name == "sweep.yaml": + if content_file.name == "sweep.yaml" or content_file.name == "sweep.yml": sweep_yml_exists = True break diff --git a/tests/archive/test_naive_chunker.py b/tests/archive/test_naive_chunker.py index fdf1112811..0cc95cd8d5 100644 --- a/tests/archive/test_naive_chunker.py +++ b/tests/archive/test_naive_chunker.py @@ -579,10 +579,10 @@ def edit_sweep_comment(message: str, index: int, pr_message="", done=False): sweep_context=sweep_context, ) - # Check repository for sweep.yml file. + # Check repository for sweep.yaml/sweep.yml file. sweep_yml_exists = False for content_file in repo.get_contents(""): - if content_file.name == "sweep.yaml": + if content_file.name == "sweep.yaml" or content_file.name == "sweep.yml": sweep_yml_exists = True break diff --git a/tests/search/test_lexical_search.py b/tests/search/test_lexical_search.py index ae693623cd..1b96ea9d77 100644 --- a/tests/search/test_lexical_search.py +++ b/tests/search/test_lexical_search.py @@ -581,10 +581,10 @@ def edit_sweep_comment(message: str, index: int, pr_message="", done=False): sweep_context=sweep_context, ) - # Check repository for sweep.yml file. + # Check repository for sweep.yaml/sweep.yml file. sweep_yml_exists = False for content_file in repo.get_contents(""): - if content_file.name == "sweep.yaml": + if content_file.name == "sweep.yaml" or content_file.name == "sweep.yml": sweep_yml_exists = True break