diff --git a/.ci/scripts/setup-macos.sh b/.ci/scripts/setup-macos.sh
index 30889fd397..aee0128461 100755
--- a/.ci/scripts/setup-macos.sh
+++ b/.ci/scripts/setup-macos.sh
@@ -104,6 +104,12 @@ print_cmake_info() {
codesign -f -s - "${CMAKE_EXEC}" || true
}
+setup_macos_env_variables() {
+ CMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')
+ export CMAKE_PREFIX_PATH
+}
+
+setup_macos_env_variables
# NB: we need buck2 in all cases because cmake build also depends on calling
# buck2 atm
install_buck
diff --git a/.ci/scripts/test_llama.sh b/.ci/scripts/test_llama.sh
index 06444785af..558e2aaccc 100644
--- a/.ci/scripts/test_llama.sh
+++ b/.ci/scripts/test_llama.sh
@@ -12,7 +12,11 @@ source "$(dirname "${BASH_SOURCE[0]}")/utils.sh"
MODEL_NAME=$1 # stories110M.pt
BUILD_TOOL=$2 # buck2 or cmake
DTYPE=$3 # fp16 or fp32
-
+MODE=${4:-"xnnpack"} # portable or xnnpack
+if [[ $# -lt 4 ]]; then # Assuming 4 mandatory args
+ echo "Expecting atleast 4 positional arguments"
+ echo "Usage: [...]"
+fi
if [[ -z "${MODEL_NAME:-}" ]]; then
echo "Missing model name, exiting..."
exit 1
@@ -28,6 +32,11 @@ if [[ -z "${DTYPE:-}" ]]; then
exit 1
fi
+if [[ -z "${MODE:-}" ]]; then
+ echo "Missing mode, choose portable or xnnpack, exiting..."
+ exit 1
+fi
+
if [[ -z "${BUCK:-}" ]]; then
BUCK=buck2
fi
@@ -42,12 +51,18 @@ which "${PYTHON_EXECUTABLE}"
cmake_install_executorch_libraries() {
echo "Installing libexecutorch.a, libextension_module.so, libportable_ops_lib.a"
rm -rf cmake-out
+ if [[ "${MODE}" == "xnnpack" ]]; then
+ XNNPACK=ON
+ else
+ XNNPACK=OFF
+ fi
retry cmake -DBUCK2="$BUCK" \
-DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_BUILD_TYPE=Release \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
-DEXECUTORCH_BUILD_OPTIMIZED=ON \
+ -DEXECUTORCH_BUILD_XNNPACK="$XNNPACK" \
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
-Bcmake-out .
cmake --build cmake-out -j9 --target install --config Release
@@ -101,7 +116,11 @@ fi
# Export model.
EXPORTED_MODEL_NAME="${EXPORTED_MODEL_NAME}.pte"
echo "Exporting ${EXPORTED_MODEL_NAME}"
-$PYTHON_EXECUTABLE -m examples.models.llama2.export_llama -c stories110M.pt -p "${PARAMS}" -d "${DTYPE}"
+EXPORT_ARGS="-c stories110M.pt -p ${PARAMS} -d ${DTYPE} -n ${EXPORTED_MODEL_NAME}"
+if [[ "${MODE}" == "xnnpack" ]]; then
+ EXPORT_ARGS="${EXPORT_ARGS} --pt2e_quantize xnnpack_dynamic"
+fi
+$PYTHON_EXECUTABLE -m examples.models.llama2.export_llama ${EXPORT_ARGS}
# Create tokenizer.bin.
echo "Creating tokenizer.bin"
diff --git a/.ci/scripts/utils.sh b/.ci/scripts/utils.sh
index 5ba8c57cdc..c7c00be257 100644
--- a/.ci/scripts/utils.sh
+++ b/.ci/scripts/utils.sh
@@ -134,8 +134,8 @@ cmake_install_executorch_lib() {
download_stories_model_artifacts() {
# Download stories110M.pt and tokenizer from Github
- wget "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt"
- wget "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model"
+ curl -Ls "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt" --output stories110M.pt
+ curl -Ls "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model" --output tokenizer.model
# Create params.json file
touch params.json
echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json
diff --git a/.github/scripts/cherry_pick.py b/.github/scripts/cherry_pick.py
new file mode 100755
index 0000000000..ece229fab0
--- /dev/null
+++ b/.github/scripts/cherry_pick.py
@@ -0,0 +1,228 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+import re
+from typing import Any, Optional
+
+from urllib.error import HTTPError
+
+from github_utils import gh_fetch_url, gh_post_pr_comment
+
+from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
+from trymerge import get_pr_commit_sha, GitHubPR
+
+
+# This is only a suggestion for now, not a strict requirement
+REQUIRES_ISSUE = {
+ "regression",
+ "critical",
+ "fixnewfeature",
+}
+
+
+def parse_args() -> Any:
+ from argparse import ArgumentParser
+
+ parser = ArgumentParser("cherry pick a landed PR onto a release branch")
+ parser.add_argument(
+ "--onto-branch", type=str, required=True, help="the target release branch"
+ )
+ parser.add_argument(
+ "--github-actor", type=str, required=True, help="all the world’s a stage"
+ )
+ parser.add_argument(
+ "--classification",
+ choices=["regression", "critical", "fixnewfeature", "docs", "release"],
+ required=True,
+ help="the cherry pick category",
+ )
+ parser.add_argument("pr_num", type=int)
+ parser.add_argument(
+ "--fixes",
+ type=str,
+ default="",
+ help="the GitHub issue that the cherry pick fixes",
+ )
+ parser.add_argument("--dry-run", action="store_true")
+
+ return parser.parse_args()
+
+
+def get_merge_commit_sha(repo: GitRepo, pr: GitHubPR) -> Optional[str]:
+ """
+ Return the merge commit SHA iff the PR has been merged. For simplicity, we
+ will only cherry pick PRs that have been merged into main
+ """
+ commit_sha = get_pr_commit_sha(repo, pr)
+ return commit_sha if pr.is_closed() else None
+
+
+def cherry_pick(
+ github_actor: str,
+ repo: GitRepo,
+ pr: GitHubPR,
+ commit_sha: str,
+ onto_branch: str,
+ classification: str,
+ fixes: str,
+ dry_run: bool = False,
+) -> None:
+ """
+ Create a local branch to cherry pick the commit and submit it as a pull request
+ """
+ current_branch = repo.current_branch()
+ cherry_pick_branch = create_cherry_pick_branch(
+ github_actor, repo, pr, commit_sha, onto_branch
+ )
+
+ try:
+ if not dry_run:
+ org, project = repo.gh_owner_and_name()
+ cherry_pick_pr = submit_pr(repo, pr, cherry_pick_branch, onto_branch)
+
+ msg = f"The cherry pick PR is at {cherry_pick_pr}"
+ if fixes:
+ msg += f" and it is linked with issue {fixes}"
+ elif classification in REQUIRES_ISSUE:
+ msg += f" and it is recommended to link a {classification} cherry pick PR with an issue"
+
+ post_comment(org, project, pr.pr_num, msg)
+
+ finally:
+ if current_branch:
+ repo.checkout(branch=current_branch)
+
+
+def create_cherry_pick_branch(
+ github_actor: str, repo: GitRepo, pr: GitHubPR, commit_sha: str, onto_branch: str
+) -> str:
+ """
+ Create a local branch and cherry pick the commit. Return the name of the local
+ cherry picking branch.
+ """
+ repo.checkout(branch=onto_branch)
+ repo._run_git("submodule", "update", "--init", "--recursive")
+
+ # Remove all special characters if we want to include the actor in the branch name
+ github_actor = re.sub("[^0-9a-zA-Z]+", "_", github_actor)
+
+ cherry_pick_branch = f"cherry-pick-{pr.pr_num}-by-{github_actor}"
+ repo.create_branch_and_checkout(branch=cherry_pick_branch)
+
+ # We might want to support ghstack later
+ repo._run_git("cherry-pick", "-x", "-X", "theirs", commit_sha)
+ repo.push(branch=cherry_pick_branch, dry_run=False)
+
+ return cherry_pick_branch
+
+
+def submit_pr(
+ repo: GitRepo,
+ pr: GitHubPR,
+ cherry_pick_branch: str,
+ onto_branch: str,
+) -> str:
+ """
+ Submit the cherry pick PR and return the link to the PR
+ """
+ org, project = repo.gh_owner_and_name()
+
+ default_msg = f"Cherry pick #{pr.pr_num} onto {onto_branch} branch"
+ title = pr.info.get("title", default_msg)
+ body = pr.info.get("body", default_msg)
+
+ try:
+ response = gh_fetch_url(
+ f"https://api.github.com/repos/{org}/{project}/pulls",
+ method="POST",
+ data={
+ "title": title,
+ "body": body,
+ "head": cherry_pick_branch,
+ "base": onto_branch,
+ },
+ headers={"Accept": "application/vnd.github.v3+json"},
+ reader=json.load,
+ )
+
+ cherry_pick_pr = response.get("html_url", "")
+ if not cherry_pick_pr:
+ raise RuntimeError(
+ f"Fail to find the cherry pick PR: {json.dumps(response)}"
+ )
+
+ return str(cherry_pick_pr)
+
+ except HTTPError as error:
+ msg = f"Fail to submit the cherry pick PR: {error}"
+ raise RuntimeError(msg) from error
+
+
+def post_comment(org: str, project: str, pr_num: int, msg: str) -> None:
+ """
+ Post a comment on the PR itself to point to the cherry picking PR when success
+ or print the error when failure
+ """
+ internal_debugging = ""
+
+ run_url = os.getenv("GH_RUN_URL")
+ # Post a comment to tell folks that the PR is being cherry picked
+ if run_url is not None:
+ internal_debugging = "\n".join(
+ line
+ for line in (
+ "Details for Dev Infra team
",
+ f'Raised by workflow job\n',
+ " ",
+ )
+ if line
+ )
+
+ comment = "\n".join(
+ (f"### Cherry picking #{pr_num}", f"{msg}", "", f"{internal_debugging}")
+ )
+ gh_post_pr_comment(org, project, pr_num, comment)
+
+
+def main() -> None:
+ args = parse_args()
+ pr_num = args.pr_num
+
+ repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
+ org, project = repo.gh_owner_and_name()
+
+ pr = GitHubPR(org, project, pr_num)
+
+ try:
+ commit_sha = get_merge_commit_sha(repo, pr)
+ if not commit_sha:
+ raise RuntimeError(
+ f"Refuse to cherry pick #{pr_num} because it hasn't been merged yet"
+ )
+
+ cherry_pick(
+ args.github_actor,
+ repo,
+ pr,
+ commit_sha,
+ args.onto_branch,
+ args.classification,
+ args.fixes,
+ args.dry_run,
+ )
+
+ except RuntimeError as error:
+ if not args.dry_run:
+ post_comment(org, project, pr_num, str(error))
+ else:
+ raise error
+
+
+if __name__ == "__main__":
+ main()
diff --git a/.github/scripts/github_utils.py b/.github/scripts/github_utils.py
new file mode 100644
index 0000000000..9db5d64c2a
--- /dev/null
+++ b/.github/scripts/github_utils.py
@@ -0,0 +1,210 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""GitHub Utilities"""
+
+import json
+import os
+import warnings
+
+from dataclasses import dataclass
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from urllib.error import HTTPError
+from urllib.parse import quote
+from urllib.request import Request, urlopen
+
+
+GITHUB_API_URL = "https://api.github.com"
+
+
+@dataclass
+class GitHubComment:
+ body_text: str
+ created_at: str
+ author_login: str
+ author_association: str
+ editor_login: Optional[str]
+ database_id: int
+ url: str
+
+
+def gh_fetch_url_and_headers(
+ url: str,
+ *,
+ headers: Optional[Dict[str, str]] = None,
+ data: Union[Optional[Dict[str, Any]], str] = None,
+ method: Optional[str] = None,
+ reader: Callable[[Any], Any] = lambda x: x.read(),
+) -> Tuple[Any, Any]:
+ if headers is None:
+ headers = {}
+ token = os.environ.get("GITHUB_TOKEN")
+ if token is not None and url.startswith(f"{GITHUB_API_URL}/"):
+ headers["Authorization"] = f"token {token}"
+
+ data_ = None
+ if data is not None:
+ data_ = data.encode() if isinstance(data, str) else json.dumps(data).encode()
+
+ try:
+ with urlopen(Request(url, headers=headers, data=data_, method=method)) as conn:
+ return conn.headers, reader(conn)
+ except HTTPError as err:
+ if err.code == 403 and all(
+ key in err.headers for key in ["X-RateLimit-Limit", "X-RateLimit-Used"]
+ ):
+ print(
+ f"""Rate limit exceeded:
+ Used: {err.headers['X-RateLimit-Used']}
+ Limit: {err.headers['X-RateLimit-Limit']}
+ Remaining: {err.headers['X-RateLimit-Remaining']}
+ Resets at: {err.headers['x-RateLimit-Reset']}"""
+ )
+ raise
+
+
+def gh_fetch_url(
+ url: str,
+ *,
+ headers: Optional[Dict[str, str]] = None,
+ data: Union[Optional[Dict[str, Any]], str] = None,
+ method: Optional[str] = None,
+ reader: Callable[[Any], Any] = lambda x: x.read(),
+) -> Any:
+ return gh_fetch_url_and_headers(
+ url, headers=headers, data=data, reader=json.load, method=method
+ )[1]
+
+
+def gh_fetch_json(
+ url: str,
+ params: Optional[Dict[str, Any]] = None,
+ data: Optional[Dict[str, Any]] = None,
+ method: Optional[str] = None,
+) -> List[Dict[str, Any]]:
+ headers = {"Accept": "application/vnd.github.v3+json"}
+ if params is not None and len(params) > 0:
+ url += "?" + "&".join(
+ f"{name}={quote(str(val))}" for name, val in params.items()
+ )
+ return cast(
+ List[Dict[str, Any]],
+ gh_fetch_url(url, headers=headers, data=data, reader=json.load, method=method),
+ )
+
+
+def _gh_fetch_json_any(
+ url: str,
+ params: Optional[Dict[str, Any]] = None,
+ data: Optional[Dict[str, Any]] = None,
+) -> Any:
+ headers = {"Accept": "application/vnd.github.v3+json"}
+ if params is not None and len(params) > 0:
+ url += "?" + "&".join(
+ f"{name}={quote(str(val))}" for name, val in params.items()
+ )
+ return gh_fetch_url(url, headers=headers, data=data, reader=json.load)
+
+
+def gh_fetch_json_list(
+ url: str,
+ params: Optional[Dict[str, Any]] = None,
+ data: Optional[Dict[str, Any]] = None,
+) -> List[Dict[str, Any]]:
+ return cast(List[Dict[str, Any]], _gh_fetch_json_any(url, params, data))
+
+
+def gh_fetch_json_dict(
+ url: str,
+ params: Optional[Dict[str, Any]] = None,
+ data: Optional[Dict[str, Any]] = None,
+) -> Dict[str, Any]:
+ return cast(Dict[str, Any], _gh_fetch_json_any(url, params, data))
+
+
+def gh_graphql(query: str, **kwargs: Any) -> Dict[str, Any]:
+ rc = gh_fetch_url(
+ "https://api.github.com/graphql",
+ data={"query": query, "variables": kwargs},
+ reader=json.load,
+ )
+ if "errors" in rc:
+ raise RuntimeError(
+ f"GraphQL query {query}, args {kwargs} failed: {rc['errors']}"
+ )
+ return cast(Dict[str, Any], rc)
+
+
+def _gh_post_comment(
+ url: str, comment: str, dry_run: bool = False
+) -> List[Dict[str, Any]]:
+ if dry_run:
+ print(comment)
+ return []
+ return gh_fetch_json_list(url, data={"body": comment})
+
+
+def gh_post_pr_comment(
+ org: str, repo: str, pr_num: int, comment: str, dry_run: bool = False
+) -> List[Dict[str, Any]]:
+ return _gh_post_comment(
+ f"{GITHUB_API_URL}/repos/{org}/{repo}/issues/{pr_num}/comments",
+ comment,
+ dry_run,
+ )
+
+
+def gh_post_commit_comment(
+ org: str, repo: str, sha: str, comment: str, dry_run: bool = False
+) -> List[Dict[str, Any]]:
+ return _gh_post_comment(
+ f"{GITHUB_API_URL}/repos/{org}/{repo}/commits/{sha}/comments",
+ comment,
+ dry_run,
+ )
+
+
+def gh_delete_comment(org: str, repo: str, comment_id: int) -> None:
+ url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues/comments/{comment_id}"
+ gh_fetch_url(url, method="DELETE")
+
+
+def gh_fetch_merge_base(org: str, repo: str, base: str, head: str) -> str:
+ merge_base = ""
+ # Get the merge base using the GitHub REST API. This is the same as using
+ # git merge-base without the need to have git. The API doc can be found at
+ # https://docs.github.com/en/rest/commits/commits?apiVersion=2022-11-28#compare-two-commits
+ try:
+ json_data = gh_fetch_url(
+ f"{GITHUB_API_URL}/repos/{org}/{repo}/compare/{base}...{head}",
+ headers={"Accept": "application/vnd.github.v3+json"},
+ reader=json.load,
+ )
+ if json_data:
+ merge_base = json_data.get("merge_base_commit", {}).get("sha", "")
+ else:
+ warnings.warn(
+ f"Failed to get merge base for {base}...{head}: Empty response"
+ )
+ except Exception as error:
+ warnings.warn(f"Failed to get merge base for {base}...{head}: {error}")
+
+ return merge_base
+
+
+def gh_update_pr_state(org: str, repo: str, pr_num: int, state: str = "open") -> None:
+ url = f"{GITHUB_API_URL}/repos/{org}/{repo}/pulls/{pr_num}"
+ try:
+ gh_fetch_url(url, method="PATCH", data={"state": state})
+ except HTTPError as err:
+ # When trying to open the pull request, error 422 means that the branch
+ # has been deleted and the API couldn't re-open it
+ if err.code == 422 and state == "open":
+ warnings.warn(
+ f"Failed to open {pr_num} because its head branch has been deleted: {err}"
+ )
+ else:
+ raise
diff --git a/.github/scripts/gitutils.py b/.github/scripts/gitutils.py
new file mode 100644
index 0000000000..7e4f63a162
--- /dev/null
+++ b/.github/scripts/gitutils.py
@@ -0,0 +1,457 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import re
+import tempfile
+from collections import defaultdict
+from datetime import datetime
+from functools import wraps
+from typing import (
+ Any,
+ Callable,
+ cast,
+ Dict,
+ Iterator,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+ Union,
+)
+
+T = TypeVar("T")
+
+RE_GITHUB_URL_MATCH = re.compile("^https://.*@?github.com/(.+)/(.+)$")
+
+
+def get_git_remote_name() -> str:
+ return os.getenv("GIT_REMOTE_NAME", "origin")
+
+
+def get_git_repo_dir() -> str:
+ from pathlib import Path
+
+ return os.getenv("GIT_REPO_DIR", str(Path(__file__).resolve().parent.parent.parent))
+
+
+def fuzzy_list_to_dict(items: List[Tuple[str, str]]) -> Dict[str, List[str]]:
+ """
+ Converts list to dict preserving elements with duplicate keys
+ """
+ rc: Dict[str, List[str]] = defaultdict(list)
+ for key, val in items:
+ rc[key].append(val)
+ return dict(rc)
+
+
+def _check_output(items: List[str], encoding: str = "utf-8") -> str:
+ from subprocess import CalledProcessError, check_output, STDOUT
+
+ try:
+ return check_output(items, stderr=STDOUT).decode(encoding)
+ except CalledProcessError as e:
+ msg = f"Command `{' '.join(e.cmd)}` returned non-zero exit code {e.returncode}"
+ stdout = e.stdout.decode(encoding) if e.stdout is not None else ""
+ stderr = e.stderr.decode(encoding) if e.stderr is not None else ""
+ # These get swallowed up, so print them here for debugging
+ print(f"stdout: \n{stdout}")
+ print(f"stderr: \n{stderr}")
+ if len(stderr) == 0:
+ msg += f"\n```\n{stdout}```"
+ else:
+ msg += f"\nstdout:\n```\n{stdout}```\nstderr:\n```\n{stderr}```"
+ raise RuntimeError(msg) from e
+
+
+class GitCommit:
+ commit_hash: str
+ title: str
+ body: str
+ author: str
+ author_date: datetime
+ commit_date: Optional[datetime]
+
+ def __init__(
+ self,
+ commit_hash: str,
+ author: str,
+ author_date: datetime,
+ title: str,
+ body: str,
+ commit_date: Optional[datetime] = None,
+ ) -> None:
+ self.commit_hash = commit_hash
+ self.author = author
+ self.author_date = author_date
+ self.commit_date = commit_date
+ self.title = title
+ self.body = body
+
+ def __repr__(self) -> str:
+ return f"{self.title} ({self.commit_hash})"
+
+ def __contains__(self, item: Any) -> bool:
+ return item in self.body or item in self.title
+
+
+def parse_fuller_format(lines: Union[str, List[str]]) -> GitCommit:
+ """
+ Expect commit message generated using `--format=fuller --date=unix` format, i.e.:
+ commit
+ Author:
+ AuthorDate:
+ Commit:
+ CommitDate:
+
+
+
+
+
+ """
+ if isinstance(lines, str):
+ lines = lines.split("\n")
+ # TODO: Handle merge commits correctly
+ if len(lines) > 1 and lines[1].startswith("Merge:"):
+ del lines[1]
+ assert len(lines) > 7
+ assert lines[0].startswith("commit")
+ assert lines[1].startswith("Author: ")
+ assert lines[2].startswith("AuthorDate: ")
+ assert lines[3].startswith("Commit: ")
+ assert lines[4].startswith("CommitDate: ")
+ assert len(lines[5]) == 0
+ return GitCommit(
+ commit_hash=lines[0].split()[1].strip(),
+ author=lines[1].split(":", 1)[1].strip(),
+ author_date=datetime.fromtimestamp(int(lines[2].split(":", 1)[1].strip())),
+ commit_date=datetime.fromtimestamp(int(lines[4].split(":", 1)[1].strip())),
+ title=lines[6].strip(),
+ body="\n".join(lines[7:]),
+ )
+
+
+class GitRepo:
+ def __init__(self, path: str, remote: str = "origin", debug: bool = False) -> None:
+ self.repo_dir = path
+ self.remote = remote
+ self.debug = debug
+
+ def _run_git(self, *args: Any) -> str:
+ if self.debug:
+ print(f"+ git -C {self.repo_dir} {' '.join(args)}")
+ return _check_output(["git", "-C", self.repo_dir] + list(args))
+
+ def revlist(self, revision_range: str) -> List[str]:
+ rc = self._run_git("rev-list", revision_range, "--", ".").strip()
+ return rc.split("\n") if len(rc) > 0 else []
+
+ def branches_containing_ref(
+ self, ref: str, *, include_remote: bool = True
+ ) -> List[str]:
+ rc = (
+ self._run_git("branch", "--remote", "--contains", ref)
+ if include_remote
+ else self._run_git("branch", "--contains", ref)
+ )
+ return [x.strip() for x in rc.split("\n") if x.strip()] if len(rc) > 0 else []
+
+ def current_branch(self) -> Optional[str]:
+ try:
+ return self._run_git("symbolic-ref", "--short", "HEAD").strip()
+ except RuntimeError:
+ # we are in detached HEAD state
+ return None
+
+ def checkout(self, branch: str) -> None:
+ self._run_git("checkout", branch)
+
+ def create_branch_and_checkout(self, branch: str) -> None:
+ self._run_git("checkout", "-b", branch)
+
+ def fetch(self, ref: Optional[str] = None, branch: Optional[str] = None) -> None:
+ if branch is None and ref is None:
+ self._run_git("fetch", self.remote)
+ elif branch is None:
+ self._run_git("fetch", self.remote, ref)
+ else:
+ self._run_git("fetch", self.remote, f"{ref}:{branch}")
+
+ def show_ref(self, name: str) -> str:
+ refs = self._run_git("show-ref", "-s", name).strip().split("\n")
+ if not all(refs[i] == refs[0] for i in range(1, len(refs))):
+ raise RuntimeError(f"reference {name} is ambiguous")
+ return refs[0]
+
+ def rev_parse(self, name: str) -> str:
+ return self._run_git("rev-parse", "--verify", name).strip()
+
+ def get_merge_base(self, from_ref: str, to_ref: str) -> str:
+ return self._run_git("merge-base", from_ref, to_ref).strip()
+
+ def patch_id(self, ref: Union[str, List[str]]) -> List[Tuple[str, str]]:
+ is_list = isinstance(ref, list)
+ if is_list:
+ if len(ref) == 0:
+ return []
+ ref = " ".join(ref)
+ rc = _check_output(
+ ["sh", "-c", f"git -C {self.repo_dir} show {ref}|git patch-id --stable"]
+ ).strip()
+ return [cast(Tuple[str, str], x.split(" ", 1)) for x in rc.split("\n")]
+
+ def commits_resolving_gh_pr(self, pr_num: int) -> List[str]:
+ owner, name = self.gh_owner_and_name()
+ msg = f"Pull Request resolved: https://github.com/{owner}/{name}/pull/{pr_num}"
+ rc = self._run_git("log", "--format=%H", "--grep", msg).strip()
+ return rc.split("\n") if len(rc) > 0 else []
+
+ def get_commit(self, ref: str) -> GitCommit:
+ return parse_fuller_format(
+ self._run_git("show", "--format=fuller", "--date=unix", "--shortstat", ref)
+ )
+
+ def cherry_pick(self, ref: str) -> None:
+ self._run_git("cherry-pick", "-x", ref)
+
+ def revert(self, ref: str) -> None:
+ self._run_git("revert", "--no-edit", ref)
+
+ def compute_branch_diffs(
+ self, from_branch: str, to_branch: str
+ ) -> Tuple[List[str], List[str]]:
+ """
+ Returns list of commmits that are missing in each other branch since their merge base
+ Might be slow if merge base is between two branches is pretty far off
+ """
+ from_ref = self.rev_parse(from_branch)
+ to_ref = self.rev_parse(to_branch)
+ merge_base = self.get_merge_base(from_ref, to_ref)
+ from_commits = self.revlist(f"{merge_base}..{from_ref}")
+ to_commits = self.revlist(f"{merge_base}..{to_ref}")
+ from_ids = fuzzy_list_to_dict(self.patch_id(from_commits))
+ to_ids = fuzzy_list_to_dict(self.patch_id(to_commits))
+ for patch_id in set(from_ids).intersection(set(to_ids)):
+ from_values = from_ids[patch_id]
+ to_values = to_ids[patch_id]
+ if len(from_values) != len(to_values):
+ # Eliminate duplicate commits+reverts from the list
+ while len(from_values) > 0 and len(to_values) > 0:
+ frc = self.get_commit(from_values.pop())
+ toc = self.get_commit(to_values.pop())
+ # FRC branch might have PR number added to the title
+ if frc.title != toc.title or frc.author_date != toc.author_date:
+ # HACK: Same commit were merged, reverted and landed again
+ # which creates a tracking problem
+ if (
+ "pytorch/pytorch" not in self.remote_url()
+ or frc.commit_hash
+ not in {
+ "0a6a1b27a464ba5be5f587cce2ee12ab8c504dbf",
+ "6d0f4a1d545a8f161df459e8d4ccafd4b9017dbe",
+ "edf909e58f06150f7be41da2f98a3b9de3167bca",
+ "a58c6aea5a0c9f8759a4154e46f544c8b03b8db1",
+ "7106d216c29ca16a3504aa2bedad948ebcf4abc2",
+ }
+ ):
+ raise RuntimeError(
+ f"Unexpected differences between {frc} and {toc}"
+ )
+ from_commits.remove(frc.commit_hash)
+ to_commits.remove(toc.commit_hash)
+ continue
+ for commit in from_values:
+ from_commits.remove(commit)
+ for commit in to_values:
+ to_commits.remove(commit)
+ # Another HACK: Patch-id is not stable for commits with binary files or for big changes across commits
+ # I.e. cherry-picking those from one branch into another will change patchid
+ if "pytorch/pytorch" in self.remote_url():
+ for excluded_commit in {
+ "8e09e20c1dafcdbdb45c2d1574da68a32e54a3a5",
+ "5f37e5c2a39c3acb776756a17730b865f0953432",
+ "b5222584e6d6990c6585981a936defd1af14c0ba",
+ "84d9a2e42d5ed30ec3b8b4140c38dd83abbce88d",
+ "f211ec90a6cdc8a2a5795478b5b5c8d7d7896f7e",
+ }:
+ if excluded_commit in from_commits:
+ from_commits.remove(excluded_commit)
+
+ return (from_commits, to_commits)
+
+ def cherry_pick_commits(self, from_branch: str, to_branch: str) -> None:
+ orig_branch = self.current_branch()
+ assert orig_branch is not None, "Must be on a branch"
+ self.checkout(to_branch)
+ from_commits, to_commits = self.compute_branch_diffs(from_branch, to_branch)
+ if len(from_commits) == 0:
+ print("Nothing to do")
+ self.checkout(orig_branch)
+ return
+ for commit in reversed(from_commits):
+ print(f"Cherry picking commit {commit}")
+ self.cherry_pick(commit)
+ self.checkout(orig_branch)
+
+ def push(self, branch: str, dry_run: bool, retry: int = 3) -> None:
+ for cnt in range(retry):
+ try:
+ if dry_run:
+ self._run_git("push", "--dry-run", self.remote, branch)
+ else:
+ self._run_git("push", self.remote, branch)
+ except RuntimeError as e:
+ print(f"{cnt} push attempt failed with {e}")
+ self.fetch()
+ self._run_git("rebase", f"{self.remote}/{branch}")
+
+ def head_hash(self) -> str:
+ return self._run_git("show-ref", "--hash", "HEAD").strip()
+
+ def remote_url(self) -> str:
+ return self._run_git("remote", "get-url", self.remote)
+
+ def gh_owner_and_name(self) -> Tuple[str, str]:
+ url = os.getenv("GIT_REMOTE_URL", None)
+ if url is None:
+ url = self.remote_url()
+ rc = RE_GITHUB_URL_MATCH.match(url)
+ if rc is None:
+ raise RuntimeError(f"Unexpected url format {url}")
+ return cast(Tuple[str, str], rc.groups())
+
+ def commit_message(self, ref: str) -> str:
+ return self._run_git("log", "-1", "--format=%B", ref)
+
+ def amend_commit_message(self, msg: str) -> None:
+ self._run_git("commit", "--amend", "-m", msg)
+
+ def diff(self, from_ref: str, to_ref: Optional[str] = None) -> str:
+ if to_ref is None:
+ return self._run_git("diff", f"{from_ref}^!")
+ return self._run_git("diff", f"{from_ref}..{to_ref}")
+
+
+def clone_repo(username: str, password: str, org: str, project: str) -> GitRepo:
+ path = tempfile.mkdtemp()
+ _check_output(
+ [
+ "git",
+ "clone",
+ f"https://{username}:{password}@github.com/{org}/{project}",
+ path,
+ ]
+ ).strip()
+ return GitRepo(path=path)
+
+
+class PeekableIterator(Iterator[str]):
+ def __init__(self, val: str) -> None:
+ self._val = val
+ self._idx = -1
+
+ def peek(self) -> Optional[str]:
+ if self._idx + 1 >= len(self._val):
+ return None
+ return self._val[self._idx + 1]
+
+ def __iter__(self) -> "PeekableIterator":
+ return self
+
+ def __next__(self) -> str:
+ rc = self.peek()
+ if rc is None:
+ raise StopIteration
+ self._idx += 1
+ return rc
+
+
+def patterns_to_regex(allowed_patterns: List[str]) -> Any:
+ """
+ pattern is glob-like, i.e. the only special sequences it has are:
+ - ? - matches single character
+ - * - matches any non-folder separator characters or no character
+ - ** - matches any characters or no character
+ Assuming that patterns are free of braces and backslashes
+ the only character that needs to be escaped are dot and plus
+ """
+ rc = "("
+ for idx, pattern in enumerate(allowed_patterns):
+ if idx > 0:
+ rc += "|"
+ pattern_ = PeekableIterator(pattern)
+ assert not any(c in pattern for c in "{}()[]\\")
+ for c in pattern_:
+ if c == ".":
+ rc += "\\."
+ elif c == "+":
+ rc += "\\+"
+ elif c == "*":
+ if pattern_.peek() == "*":
+ next(pattern_)
+ rc += ".*"
+ else:
+ rc += "[^/]*"
+ else:
+ rc += c
+ rc += ")"
+ return re.compile(rc)
+
+
+def _shasum(value: str) -> str:
+ import hashlib
+
+ m = hashlib.sha256()
+ m.update(value.encode("utf-8"))
+ return m.hexdigest()
+
+
+def is_commit_hash(ref: str) -> bool:
+ "True if ref is hexadecimal number, else false"
+ try:
+ int(ref, 16)
+ except ValueError:
+ return False
+ return True
+
+
+def are_ghstack_branches_in_sync(
+ repo: GitRepo, head_ref: str, base_ref: Optional[str] = None
+) -> bool:
+ """Checks that diff between base and head is the same as diff between orig and its parent"""
+ orig_ref = re.sub(r"/head$", "/orig", head_ref)
+ if base_ref is None:
+ base_ref = re.sub(r"/head$", "/base", head_ref)
+ orig_diff_sha = _shasum(repo.diff(f"{repo.remote}/{orig_ref}"))
+ head_diff_sha = _shasum(
+ repo.diff(
+ base_ref if is_commit_hash(base_ref) else f"{repo.remote}/{base_ref}",
+ f"{repo.remote}/{head_ref}",
+ )
+ )
+ return orig_diff_sha == head_diff_sha
+
+
+def retries_decorator(
+ rc: Any = None, num_retries: int = 3
+) -> Callable[[Callable[..., T]], Callable[..., T]]:
+ def decorator(f: Callable[..., T]) -> Callable[..., T]:
+ @wraps(f)
+ def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> T:
+ for idx in range(num_retries):
+ try:
+ return f(*args, **kwargs)
+ except Exception as e:
+ print(
+ f'Attempt {idx} of {num_retries} to call {f.__name__} failed with "{e}"'
+ )
+ pass
+ return cast(T, rc)
+
+ return wrapper
+
+ return decorator
diff --git a/.github/scripts/label_utils.py b/.github/scripts/label_utils.py
new file mode 100644
index 0000000000..81668dad0c
--- /dev/null
+++ b/.github/scripts/label_utils.py
@@ -0,0 +1,132 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""GitHub Label Utilities."""
+
+import json
+
+from functools import lru_cache
+from typing import Any, List, Tuple, TYPE_CHECKING, Union
+
+from github_utils import gh_fetch_url_and_headers, GitHubComment
+
+# TODO: this is a temp workaround to avoid circular dependencies,
+# and should be removed once GitHubPR is refactored out of trymerge script.
+if TYPE_CHECKING:
+ from trymerge import GitHubPR
+
+BOT_AUTHORS = ["github-actions", "pytorchmergebot", "pytorch-bot"]
+
+LABEL_ERR_MSG_TITLE = "This PR needs a `release notes:` label"
+LABEL_ERR_MSG = f"""# {LABEL_ERR_MSG_TITLE}
+If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`.
+
+If not, please add the `topic: not user facing` label.
+
+To add a label, you can comment to pytorchbot, for example
+`@pytorchbot label "topic: not user facing"`
+
+For more information, see
+https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.
+"""
+
+
+def request_for_labels(url: str) -> Tuple[Any, Any]:
+ headers = {"Accept": "application/vnd.github.v3+json"}
+ return gh_fetch_url_and_headers(
+ url, headers=headers, reader=lambda x: x.read().decode("utf-8")
+ )
+
+
+def update_labels(labels: List[str], info: str) -> None:
+ labels_json = json.loads(info)
+ labels.extend([x["name"] for x in labels_json])
+
+
+def get_last_page_num_from_header(header: Any) -> int:
+ # Link info looks like: ;
+ # rel="next", ; rel="last"
+ link_info = header["link"]
+ # Docs does not specify that it should be present for projects with just few labels
+ # And https://github.com/malfet/deleteme/actions/runs/7334565243/job/19971396887 it's not the case
+ if link_info is None:
+ return 1
+ prefix = "&page="
+ suffix = ">;"
+ return int(
+ link_info[link_info.rindex(prefix) + len(prefix) : link_info.rindex(suffix)]
+ )
+
+
+@lru_cache
+def gh_get_labels(org: str, repo: str) -> List[str]:
+ prefix = f"https://api.github.com/repos/{org}/{repo}/labels?per_page=100"
+ header, info = request_for_labels(prefix + "&page=1")
+ labels: List[str] = []
+ update_labels(labels, info)
+
+ last_page = get_last_page_num_from_header(header)
+ assert (
+ last_page > 0
+ ), "Error reading header info to determine total number of pages of labels"
+ for page_number in range(2, last_page + 1): # skip page 1
+ _, info = request_for_labels(prefix + f"&page={page_number}")
+ update_labels(labels, info)
+
+ return labels
+
+
+def gh_add_labels(
+ org: str, repo: str, pr_num: int, labels: Union[str, List[str]], dry_run: bool
+) -> None:
+ if dry_run:
+ print(f"Dryrun: Adding labels {labels} to PR {pr_num}")
+ return
+ gh_fetch_url_and_headers(
+ url=f"https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/labels",
+ data={"labels": labels},
+ )
+
+
+def gh_remove_label(
+ org: str, repo: str, pr_num: int, label: str, dry_run: bool
+) -> None:
+ if dry_run:
+ print(f"Dryrun: Removing {label} from PR {pr_num}")
+ return
+ gh_fetch_url_and_headers(
+ url=f"https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/labels/{label}",
+ method="DELETE",
+ )
+
+
+def get_release_notes_labels(org: str, repo: str) -> List[str]:
+ return [
+ label
+ for label in gh_get_labels(org, repo)
+ if label.lstrip().startswith("release notes:")
+ ]
+
+
+def has_required_labels(pr: "GitHubPR") -> bool:
+ pr_labels = pr.get_labels()
+ # Check if PR is not user facing
+ is_not_user_facing_pr = any(
+ label.strip() == "topic: not user facing" for label in pr_labels
+ )
+ return is_not_user_facing_pr or any(
+ label.strip() in get_release_notes_labels(pr.org, pr.project)
+ for label in pr_labels
+ )
+
+
+def is_label_err_comment(comment: GitHubComment) -> bool:
+ # comment.body_text returns text without markdown
+ no_format_title = LABEL_ERR_MSG_TITLE.replace("`", "")
+ return (
+ comment.body_text.lstrip(" #").startswith(no_format_title)
+ and comment.author_login in BOT_AUTHORS
+ )
diff --git a/.github/scripts/trymerge.py b/.github/scripts/trymerge.py
new file mode 100755
index 0000000000..0f0e1e30cf
--- /dev/null
+++ b/.github/scripts/trymerge.py
@@ -0,0 +1,2372 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# NB: the following functions are used in Meta-internal workflows
+# (github_first_try_merge/my_handler.py) and thus have functionality limitations
+# (no `git` command access, no network access besides the strict allow list):
+#
+# find_matching_merge_rule
+# read_merge_rules
+#
+# Also any signature changes of these functions, as well as changes to the `GitHubPR`
+# class, will likely require corresponding changes for the internal workflows.
+
+import base64
+import json
+import os
+import re
+import time
+import urllib.parse
+from collections import defaultdict
+from dataclasses import dataclass
+from functools import lru_cache
+from pathlib import Path
+from typing import (
+ Any,
+ Callable,
+ cast,
+ Dict,
+ Iterable,
+ List,
+ NamedTuple,
+ Optional,
+ Pattern,
+ Tuple,
+)
+from warnings import warn
+
+import yaml
+from github_utils import (
+ gh_fetch_json_list,
+ gh_fetch_merge_base,
+ gh_fetch_url,
+ gh_graphql,
+ gh_post_commit_comment,
+ gh_post_pr_comment,
+ gh_update_pr_state,
+ GitHubComment,
+)
+
+from gitutils import (
+ are_ghstack_branches_in_sync,
+ get_git_remote_name,
+ get_git_repo_dir,
+ GitRepo,
+ patterns_to_regex,
+ retries_decorator,
+)
+from label_utils import (
+ gh_add_labels,
+ gh_remove_label,
+ has_required_labels,
+ LABEL_ERR_MSG,
+)
+from trymerge_explainer import get_revert_message, TryMergeExplainer
+
+# labels
+MERGE_IN_PROGRESS_LABEL = "merging"
+MERGE_COMPLETE_LABEL = "merged"
+
+
+class JobCheckState(NamedTuple):
+ name: str
+ url: str
+ status: Optional[str]
+ classification: Optional[str]
+ job_id: Optional[int]
+ title: Optional[str]
+ summary: Optional[str]
+
+
+JobNameToStateDict = Dict[str, JobCheckState]
+
+
+class WorkflowCheckState:
+ def __init__(self, name: str, url: str, status: Optional[str]):
+ self.name: str = name
+ self.url: str = url
+ self.status: Optional[str] = status
+ self.jobs: JobNameToStateDict = {}
+
+
+GH_PR_REVIEWS_FRAGMENT = """
+fragment PRReviews on PullRequestReviewConnection {
+ nodes {
+ author {
+ login
+ }
+ bodyText
+ createdAt
+ authorAssociation
+ editor {
+ login
+ }
+ databaseId
+ url
+ state
+ }
+ pageInfo {
+ startCursor
+ hasPreviousPage
+ }
+}
+"""
+
+GH_CHECKSUITES_FRAGMENT = """
+fragment PRCheckSuites on CheckSuiteConnection {
+ edges {
+ node {
+ app {
+ name
+ databaseId
+ }
+ workflowRun {
+ workflow {
+ name
+ }
+ url
+ }
+ checkRuns(first: 50) {
+ nodes {
+ name
+ conclusion
+ detailsUrl
+ databaseId
+ title
+ summary
+ }
+ pageInfo {
+ endCursor
+ hasNextPage
+ }
+ }
+ conclusion
+ }
+ cursor
+ }
+ pageInfo {
+ hasNextPage
+ }
+}
+"""
+
+GH_COMMIT_AUTHORS_FRAGMENT = """
+fragment CommitAuthors on PullRequestCommitConnection {
+ nodes {
+ commit {
+ authors(first: 2) {
+ nodes {
+ user {
+ login
+ }
+ email
+ name
+ }
+ }
+ oid
+ }
+ }
+ pageInfo {
+ endCursor
+ hasNextPage
+ }
+}
+"""
+
+GH_GET_PR_INFO_QUERY = (
+ GH_PR_REVIEWS_FRAGMENT
+ + GH_CHECKSUITES_FRAGMENT
+ + GH_COMMIT_AUTHORS_FRAGMENT
+ + """
+query ($owner: String!, $name: String!, $number: Int!) {
+ repository(owner: $owner, name: $name) {
+ pullRequest(number: $number) {
+ closed
+ isCrossRepository
+ author {
+ login
+ }
+ title
+ body
+ headRefName
+ headRepository {
+ nameWithOwner
+ }
+ baseRefName
+ baseRefOid
+ baseRepository {
+ nameWithOwner
+ isPrivate
+ defaultBranchRef {
+ name
+ }
+ }
+ mergeCommit {
+ oid
+ }
+ commits_with_authors: commits(first: 100) {
+ ...CommitAuthors
+ totalCount
+ }
+ commits(last: 1) {
+ nodes {
+ commit {
+ checkSuites(first: 10) {
+ ...PRCheckSuites
+ }
+ status {
+ contexts {
+ context
+ state
+ targetUrl
+ }
+ }
+ oid
+ }
+ }
+ }
+ changedFiles
+ files(first: 100) {
+ nodes {
+ path
+ }
+ pageInfo {
+ endCursor
+ hasNextPage
+ }
+ }
+ reviews(last: 100) {
+ ...PRReviews
+ }
+ comments(last: 5) {
+ nodes {
+ bodyText
+ createdAt
+ author {
+ login
+ }
+ authorAssociation
+ editor {
+ login
+ }
+ databaseId
+ url
+ }
+ pageInfo {
+ startCursor
+ hasPreviousPage
+ }
+ }
+ labels(first: 100) {
+ edges {
+ node {
+ name
+ }
+ }
+ }
+ }
+ }
+}
+"""
+)
+
+GH_GET_PR_NEXT_FILES_QUERY = """
+query ($owner: String!, $name: String!, $number: Int!, $cursor: String!) {
+ repository(name: $name, owner: $owner) {
+ pullRequest(number: $number) {
+ files(first: 100, after: $cursor) {
+ nodes {
+ path
+ }
+ pageInfo {
+ endCursor
+ hasNextPage
+ }
+ }
+ }
+ }
+}
+"""
+
+GH_GET_PR_NEXT_CHECKSUITES = (
+ GH_CHECKSUITES_FRAGMENT
+ + """
+query ($owner: String!, $name: String!, $number: Int!, $cursor: String!) {
+ repository(name: $name, owner: $owner) {
+ pullRequest(number: $number) {
+ commits(last: 1) {
+ nodes {
+ commit {
+ oid
+ checkSuites(first: 10, after: $cursor) {
+ ...PRCheckSuites
+ }
+ }
+ }
+ }
+ }
+ }
+}
+"""
+)
+
+GH_GET_PR_NEXT_CHECK_RUNS = """
+query ($owner: String!, $name: String!, $number: Int!, $cs_cursor: String, $cr_cursor: String!) {
+ repository(name: $name, owner: $owner) {
+ pullRequest(number: $number) {
+ commits(last: 1) {
+ nodes {
+ commit {
+ oid
+ checkSuites(first: 1, after: $cs_cursor) {
+ nodes {
+ checkRuns(first: 100, after: $cr_cursor) {
+ nodes {
+ name
+ conclusion
+ detailsUrl
+ databaseId
+ title
+ summary
+ }
+ pageInfo {
+ endCursor
+ hasNextPage
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+"""
+
+GH_GET_PR_PREV_COMMENTS = """
+query ($owner: String!, $name: String!, $number: Int!, $cursor: String!) {
+ repository(name: $name, owner: $owner) {
+ pullRequest(number: $number) {
+ comments(last: 100, before: $cursor) {
+ nodes {
+ bodyText
+ createdAt
+ author {
+ login
+ }
+ authorAssociation
+ editor {
+ login
+ }
+ databaseId
+ url
+ }
+ pageInfo {
+ startCursor
+ hasPreviousPage
+ }
+ }
+ }
+ }
+}
+"""
+
+# This query needs read-org permission
+GH_GET_TEAM_MEMBERS_QUERY = """
+query($org: String!, $name: String!, $cursor: String) {
+ organization(login: $org) {
+ team(slug: $name) {
+ members(first: 100, after: $cursor) {
+ nodes {
+ login
+ }
+ pageInfo {
+ hasNextPage
+ endCursor
+ }
+ }
+ }
+ }
+}
+"""
+
+GH_GET_PR_NEXT_AUTHORS_QUERY = (
+ GH_COMMIT_AUTHORS_FRAGMENT
+ + """
+query ($owner: String!, $name: String!, $number: Int!, $cursor: String) {
+ repository(name: $name, owner: $owner) {
+ pullRequest(number: $number) {
+ commits_with_authors: commits(first: 100, after: $cursor) {
+ ...CommitAuthors
+ }
+ }
+ }
+}
+"""
+)
+
+GH_GET_PR_PREV_REVIEWS_QUERY = (
+ GH_PR_REVIEWS_FRAGMENT
+ + """
+query ($owner: String!, $name: String!, $number: Int!, $cursor: String!) {
+ repository(name: $name, owner: $owner) {
+ pullRequest(number: $number) {
+ reviews(last: 100, before: $cursor) {
+ ...PRReviews
+ }
+ }
+ }
+}
+"""
+)
+
+GH_GET_REPO_SUBMODULES = """
+query ($owner: String!, $name: String!) {
+ repository(owner: $owner, name: $name) {
+ submodules(first: 100) {
+ nodes {
+ path
+ }
+ pageInfo {
+ endCursor
+ hasNextPage
+ }
+ }
+ }
+}
+"""
+
+RE_GHSTACK_HEAD_REF = re.compile(r"^(gh/[^/]+/[0-9]+/)head$")
+RE_GHSTACK_DESC = re.compile(r"Stack.*:\r?\n(\* [^\r\n]+\r?\n)+", re.MULTILINE)
+RE_PULL_REQUEST_RESOLVED = re.compile(
+ r"Pull Request resolved: "
+ r"https://github.com/(?P[^/]+)/(?P[^/]+)/pull/(?P[0-9]+)",
+ re.MULTILINE,
+)
+RE_PR_CC_LINE = re.compile(r"^cc:? @\w+.*\r?\n?$", re.MULTILINE)
+RE_DIFF_REV = re.compile(r"^Differential Revision:.+?(D[0-9]+)", re.MULTILINE)
+CIFLOW_LABEL = re.compile(r"^ciflow/.+")
+CIFLOW_TRUNK_LABEL = re.compile(r"^ciflow/trunk")
+MERGE_RULE_PATH = Path(".github") / "merge_rules.yaml"
+ROCKSET_MERGES_COLLECTION = "merges"
+ROCKSET_MERGES_WORKSPACE = "commons"
+REMOTE_MAIN_BRANCH = "origin/main"
+DRCI_CHECKRUN_NAME = "Dr.CI"
+INTERNAL_CHANGES_CHECKRUN_NAME = "Meta Internal-Only Changes Check"
+HAS_NO_CONNECTED_DIFF_TITLE = (
+ "There is no internal Diff connected, this can be merged now"
+)
+# This could be set to -1 to ignore all flaky and broken trunk failures. On the
+# other hand, using a large value like 10 here might be useful in sev situation
+IGNORABLE_FAILED_CHECKS_THESHOLD = 10
+
+
+def gh_get_pr_info(org: str, proj: str, pr_no: int) -> Any:
+ rc = gh_graphql(GH_GET_PR_INFO_QUERY, name=proj, owner=org, number=pr_no)
+ return rc["data"]["repository"]["pullRequest"]
+
+
+@lru_cache(maxsize=None)
+def gh_get_team_members(org: str, name: str) -> List[str]:
+ rc: List[str] = []
+ team_members: Dict[str, Any] = {
+ "pageInfo": {"hasNextPage": "true", "endCursor": None}
+ }
+ while bool(team_members["pageInfo"]["hasNextPage"]):
+ query = gh_graphql(
+ GH_GET_TEAM_MEMBERS_QUERY,
+ org=org,
+ name=name,
+ cursor=team_members["pageInfo"]["endCursor"],
+ )
+ team = query["data"]["organization"]["team"]
+ if team is None:
+ warn(f"Requested non-existing team {org}/{name}")
+ return []
+ team_members = team["members"]
+ rc += [member["login"] for member in team_members["nodes"]]
+ return rc
+
+
+def get_check_run_name_prefix(workflow_run: Any) -> str:
+ if workflow_run is None:
+ return ""
+ else:
+ return f'{workflow_run["workflow"]["name"]} / '
+
+
+def is_passing_status(status: Optional[str]) -> bool:
+ return status is not None and status.upper() in ["SUCCESS", "SKIPPED", "NEUTRAL"]
+
+
+def add_workflow_conclusions(
+ checksuites: Any,
+ get_next_checkruns_page: Callable[[List[Dict[str, Dict[str, Any]]], int, Any], Any],
+ get_next_checksuites: Callable[[Any], Any],
+) -> JobNameToStateDict:
+ # graphql seems to favor the most recent workflow run, so in theory we
+ # shouldn't need to account for reruns, but do it just in case
+
+ # workflow -> job -> job info
+ workflows: Dict[str, WorkflowCheckState] = {}
+
+ # for the jobs that don't have a workflow
+ no_workflow_obj: WorkflowCheckState = WorkflowCheckState("", "", None)
+
+ def add_conclusions(edges: Any) -> None:
+ for edge_idx, edge in enumerate(edges):
+ node = edge["node"]
+ workflow_run = node["workflowRun"]
+ checkruns = node["checkRuns"]
+
+ workflow_obj: WorkflowCheckState = no_workflow_obj
+
+ if workflow_run is not None:
+ workflow_name = workflow_run["workflow"]["name"]
+ workflow_conclusion = node["conclusion"]
+ # Do not override existing status with cancelled
+ if workflow_conclusion == "CANCELLED" and workflow_name in workflows:
+ continue
+ if workflow_name not in workflows:
+ workflows[workflow_name] = WorkflowCheckState(
+ name=workflow_name,
+ status=workflow_conclusion,
+ url=workflow_run["url"],
+ )
+ workflow_obj = workflows[workflow_name]
+
+ while checkruns is not None:
+ for checkrun_node in checkruns["nodes"]:
+ if not isinstance(checkrun_node, dict):
+ warn(f"Expected dictionary, but got {type(checkrun_node)}")
+ continue
+ checkrun_name = f'{get_check_run_name_prefix(workflow_run)}{checkrun_node["name"]}'
+ existing_checkrun = workflow_obj.jobs.get(checkrun_name)
+ if existing_checkrun is None or not is_passing_status(
+ existing_checkrun.status
+ ):
+ workflow_obj.jobs[checkrun_name] = JobCheckState(
+ checkrun_name,
+ checkrun_node["detailsUrl"],
+ checkrun_node["conclusion"],
+ classification=None,
+ job_id=checkrun_node["databaseId"],
+ title=checkrun_node["title"],
+ summary=checkrun_node["summary"],
+ )
+
+ if bool(checkruns["pageInfo"]["hasNextPage"]):
+ checkruns = get_next_checkruns_page(edges, edge_idx, checkruns)
+ else:
+ checkruns = None
+
+ all_edges = checksuites["edges"].copy()
+ while bool(checksuites["pageInfo"]["hasNextPage"]):
+ checksuites = get_next_checksuites(checksuites)
+ all_edges.extend(checksuites["edges"])
+
+ add_conclusions(all_edges)
+
+ # Flatten the dictionaries. If there exists jobs in the workflow run, put
+ # the jobs in but don't put the workflow in. We care more about the jobs in
+ # the workflow that ran than the container workflow.
+ res: JobNameToStateDict = {}
+ for workflow_name, workflow in workflows.items():
+ if len(workflow.jobs) > 0:
+ for job_name, job in workflow.jobs.items():
+ res[job_name] = job
+ else:
+ res[workflow_name] = JobCheckState(
+ workflow.name,
+ workflow.url,
+ workflow.status,
+ classification=None,
+ job_id=None,
+ title=None,
+ summary=None,
+ )
+ for job_name, job in no_workflow_obj.jobs.items():
+ res[job_name] = job
+ return res
+
+
+def parse_args() -> Any:
+ from argparse import ArgumentParser
+
+ parser = ArgumentParser("Merge PR into default branch")
+ parser.add_argument("--dry-run", action="store_true")
+ parser.add_argument("--revert", action="store_true")
+ parser.add_argument("--force", action="store_true")
+ parser.add_argument("--ignore-current", action="store_true")
+ parser.add_argument("--check-mergeability", action="store_true")
+ parser.add_argument("--comment-id", type=int)
+ parser.add_argument("--reason", type=str)
+ parser.add_argument("pr_num", type=int)
+ return parser.parse_args()
+
+
+def can_skip_internal_checks(pr: "GitHubPR", comment_id: Optional[int] = None) -> bool:
+ if comment_id is None:
+ return False
+ comment = pr.get_comment_by_id(comment_id)
+ if comment.editor_login is not None:
+ return False
+ return comment.author_login == "facebook-github-bot"
+
+
+def _revlist_to_prs(
+ repo: GitRepo,
+ pr: "GitHubPR",
+ rev_list: Iterable[str],
+ should_skip: Optional[Callable[[int, "GitHubPR"], bool]] = None,
+) -> List[Tuple["GitHubPR", str]]:
+ rc: List[Tuple[GitHubPR, str]] = []
+ for idx, rev in enumerate(rev_list):
+ msg = repo.commit_message(rev)
+ m = RE_PULL_REQUEST_RESOLVED.search(msg)
+ if m is None:
+ raise RuntimeError(
+ f"Could not find PR-resolved string in {msg} of ghstacked PR {pr.pr_num}"
+ )
+ if pr.org != m.group("owner") or pr.project != m.group("repo"):
+ raise RuntimeError(
+ f"PR {m.group('number')} resolved to wrong owner/repo pair"
+ )
+ pr_num = int(m.group("number"))
+ candidate = GitHubPR(pr.org, pr.project, pr_num) if pr_num != pr.pr_num else pr
+ if should_skip is not None and should_skip(idx, candidate):
+ continue
+ rc.append((candidate, rev))
+ return rc
+
+
+def get_ghstack_prs(
+ repo: GitRepo, pr: "GitHubPR", open_only: bool = True
+) -> List[Tuple["GitHubPR", str]]:
+ """
+ Get the PRs in the stack that are below this PR (inclusive). Throws error if any of the open PRs are out of sync.
+ @:param open_only: Only return open PRs
+ """
+ # For ghstack, cherry-pick commits based from origin
+ orig_ref = f"{repo.remote}/{pr.get_ghstack_orig_ref()}"
+ rev_list = repo.revlist(f"{pr.default_branch()}..{orig_ref}")
+
+ def skip_func(idx: int, candidate: "GitHubPR") -> bool:
+ if not open_only or not candidate.is_closed():
+ return False
+ print(
+ f"Skipping {idx+1} of {len(rev_list)} PR (#{candidate.pr_num}) as its already been merged"
+ )
+ return True
+
+ assert pr.is_ghstack_pr()
+ entire_stack = _revlist_to_prs(repo, pr, reversed(rev_list), skip_func)
+
+ for stacked_pr, rev in entire_stack:
+ if stacked_pr.is_closed():
+ continue
+ base_ref = stacked_pr.base_ref()
+ if base_ref == pr.default_branch():
+ base_ref = repo.get_merge_base(
+ f"{repo.remote}/{base_ref}", f"{repo.remote}/{stacked_pr.head_ref()}"
+ )
+ if not are_ghstack_branches_in_sync(repo, stacked_pr.head_ref(), base_ref):
+ raise RuntimeError(
+ f"PR {stacked_pr.pr_num} is out of sync with the corresponding revision {rev} on "
+ + f"branch {stacked_pr.get_ghstack_orig_ref()} that would be merged into {stacked_pr.default_branch()}. "
+ + "This usually happens because there is a non ghstack change in the PR. "
+ + f"Please sync them and try again (ex. make the changes on {orig_ref} and run ghstack)."
+ )
+ return entire_stack
+
+
+class GitHubPR:
+ def __init__(self, org: str, project: str, pr_num: int) -> None:
+ assert isinstance(pr_num, int)
+ self.org = org
+ self.project = project
+ self.pr_num = pr_num
+ self.info = gh_get_pr_info(org, project, pr_num)
+ self.changed_files: Optional[List[str]] = None
+ self.labels: Optional[List[str]] = None
+ self.conclusions: Optional[JobNameToStateDict] = None
+ self.comments: Optional[List[GitHubComment]] = None
+ self._authors: Optional[List[Tuple[str, str]]] = None
+ self._reviews: Optional[List[Tuple[str, str]]] = None
+ self.merge_base: Optional[str] = None
+ self.submodules: Optional[List[str]] = None
+
+ def is_closed(self) -> bool:
+ return bool(self.info["closed"])
+
+ def is_cross_repo(self) -> bool:
+ return bool(self.info["isCrossRepository"])
+
+ def base_ref(self) -> str:
+ return cast(str, self.info["baseRefName"])
+
+ def default_branch(self) -> str:
+ return cast(str, self.info["baseRepository"]["defaultBranchRef"]["name"])
+
+ def head_ref(self) -> str:
+ return cast(str, self.info["headRefName"])
+
+ def is_ghstack_pr(self) -> bool:
+ return RE_GHSTACK_HEAD_REF.match(self.head_ref()) is not None
+
+ def get_ghstack_orig_ref(self) -> str:
+ assert self.is_ghstack_pr()
+ return re.sub(r"/head$", "/orig", self.head_ref())
+
+ def is_base_repo_private(self) -> bool:
+ return bool(self.info["baseRepository"]["isPrivate"])
+
+ def get_changed_files_count(self) -> int:
+ return int(self.info["changedFiles"])
+
+ def last_commit(self) -> Any:
+ return self.info["commits"]["nodes"][-1]["commit"]
+
+ def get_merge_base(self) -> str:
+ if self.merge_base:
+ return self.merge_base
+
+ last_commit_oid = self.last_commit()["oid"]
+ # NB: We could use self.base_ref() here for regular PR, however, that doesn't
+ # work for ghstack where the base is the custom branch, i.e. gh/USER/ID/base,
+ # so let's just use main instead
+ self.merge_base = gh_fetch_merge_base(
+ self.org, self.project, last_commit_oid, self.default_branch()
+ )
+
+ # Fallback to baseRefOid if the API call fails, i.e. rate limit. Note that baseRefOid
+ # points to the base ref associated with the PR or, in other words, the head of main
+ # when the PR is created or rebased. This is not necessarily the merge base commit,
+ # but it could serve as a fallback in most cases and it's readily available as part
+ # of the PR info
+ if not self.merge_base:
+ self.merge_base = cast(str, self.info["baseRefOid"])
+
+ return self.merge_base
+
+ def get_changed_files(self) -> List[str]:
+ if self.changed_files is None:
+ info = self.info
+ unique_changed_files = set()
+ # Do not try to fetch more than 10K files
+ for _ in range(100):
+ unique_changed_files.update([x["path"] for x in info["files"]["nodes"]])
+ if not info["files"]["pageInfo"]["hasNextPage"]:
+ break
+ rc = gh_graphql(
+ GH_GET_PR_NEXT_FILES_QUERY,
+ name=self.project,
+ owner=self.org,
+ number=self.pr_num,
+ cursor=info["files"]["pageInfo"]["endCursor"],
+ )
+ info = rc["data"]["repository"]["pullRequest"]
+ self.changed_files = list(unique_changed_files)
+
+ if len(self.changed_files) != self.get_changed_files_count():
+ raise RuntimeError("Changed file count mismatch")
+ return self.changed_files
+
+ def get_submodules(self) -> List[str]:
+ if self.submodules is None:
+ rc = gh_graphql(GH_GET_REPO_SUBMODULES, name=self.project, owner=self.org)
+ info = rc["data"]["repository"]["submodules"]
+ self.submodules = [s["path"] for s in info["nodes"]]
+ return self.submodules
+
+ def get_changed_submodules(self) -> List[str]:
+ submodules = self.get_submodules()
+ return [f for f in self.get_changed_files() if f in submodules]
+
+ def has_invalid_submodule_updates(self) -> bool:
+ """Submodule updates in PR are invalid if submodule keyword
+ is not mentioned in neither the title nor body/description
+ nor in any of the labels.
+ """
+ return (
+ len(self.get_changed_submodules()) > 0
+ and "submodule" not in self.get_title().lower()
+ and "submodule" not in self.get_body().lower()
+ and all("submodule" not in label for label in self.get_labels())
+ )
+
+ def _get_reviews(self) -> List[Tuple[str, str]]:
+ if self._reviews is None:
+ self._reviews = []
+ info = self.info
+ for _ in range(100):
+ nodes = info["reviews"]["nodes"]
+ self._reviews = [
+ (node["author"]["login"], node["state"]) for node in nodes
+ ] + self._reviews
+ if not info["reviews"]["pageInfo"]["hasPreviousPage"]:
+ break
+ rc = gh_graphql(
+ GH_GET_PR_PREV_REVIEWS_QUERY,
+ name=self.project,
+ owner=self.org,
+ number=self.pr_num,
+ cursor=info["reviews"]["pageInfo"]["startCursor"],
+ )
+ info = rc["data"]["repository"]["pullRequest"]
+ reviews = {}
+ for author, state in self._reviews:
+ if state != "COMMENTED":
+ reviews[author] = state
+ return list(reviews.items())
+
+ def get_approved_by(self) -> List[str]:
+ return [login for (login, state) in self._get_reviews() if state == "APPROVED"]
+
+ def get_commit_count(self) -> int:
+ return int(self.info["commits_with_authors"]["totalCount"])
+
+ def get_pr_creator_login(self) -> str:
+ return cast(str, self.info["author"]["login"])
+
+ def _fetch_authors(self) -> List[Tuple[str, str]]:
+ if self._authors is not None:
+ return self._authors
+ authors: List[Tuple[str, str]] = []
+
+ def add_authors(info: Dict[str, Any]) -> None:
+ for node in info["commits_with_authors"]["nodes"]:
+ for author_node in node["commit"]["authors"]["nodes"]:
+ user_node = author_node["user"]
+ author = f"{author_node['name']} <{author_node['email']}>"
+ if user_node is None:
+ # If author is not github user, user node will be null
+ authors.append(("", author))
+ else:
+ authors.append((cast(str, user_node["login"]), author))
+
+ info = self.info
+ for _ in range(100):
+ add_authors(info)
+ if not info["commits_with_authors"]["pageInfo"]["hasNextPage"]:
+ break
+ rc = gh_graphql(
+ GH_GET_PR_NEXT_AUTHORS_QUERY,
+ name=self.project,
+ owner=self.org,
+ number=self.pr_num,
+ cursor=info["commits_with_authors"]["pageInfo"]["endCursor"],
+ )
+ info = rc["data"]["repository"]["pullRequest"]
+ self._authors = authors
+ return authors
+
+ def get_committer_login(self, num: int = 0) -> str:
+ return self._fetch_authors()[num][0]
+
+ def get_committer_author(self, num: int = 0) -> str:
+ return self._fetch_authors()[num][1]
+
+ def get_labels(self) -> List[str]:
+ if self.labels is not None:
+ return self.labels
+ labels = (
+ [node["node"]["name"] for node in self.info["labels"]["edges"]]
+ if "labels" in self.info
+ else []
+ )
+ self.labels = labels
+ return self.labels
+
+ def get_checkrun_conclusions(self) -> JobNameToStateDict:
+ """Returns dict of checkrun -> [conclusion, url]"""
+ if self.conclusions is not None:
+ return self.conclusions
+ orig_last_commit = self.last_commit()
+
+ def get_pr_next_check_runs(
+ edges: List[Dict[str, Dict[str, Any]]], edge_idx: int, checkruns: Any
+ ) -> Any:
+ rc = gh_graphql(
+ GH_GET_PR_NEXT_CHECK_RUNS,
+ name=self.project,
+ owner=self.org,
+ number=self.pr_num,
+ cs_cursor=edges[edge_idx - 1]["cursor"] if edge_idx > 0 else None,
+ cr_cursor=checkruns["pageInfo"]["endCursor"],
+ )
+ last_commit = rc["data"]["repository"]["pullRequest"]["commits"]["nodes"][
+ -1
+ ]["commit"]
+ checkruns = last_commit["checkSuites"]["nodes"][-1]["checkRuns"]
+ return checkruns
+
+ def get_pr_next_checksuites(checksuites: Any) -> Any:
+ rc = gh_graphql(
+ GH_GET_PR_NEXT_CHECKSUITES,
+ name=self.project,
+ owner=self.org,
+ number=self.pr_num,
+ cursor=checksuites["edges"][-1]["cursor"],
+ )
+ info = rc["data"]["repository"]["pullRequest"]
+ last_commit = info["commits"]["nodes"][-1]["commit"]
+ if last_commit["oid"] != orig_last_commit["oid"]:
+ raise RuntimeError("Last commit changed on PR")
+ return last_commit["checkSuites"]
+
+ checksuites = orig_last_commit["checkSuites"]
+
+ self.conclusions = add_workflow_conclusions(
+ checksuites, get_pr_next_check_runs, get_pr_next_checksuites
+ )
+
+ # Append old style statuses(like ones populated by CircleCI or EasyCLA) to conclusions
+ if orig_last_commit["status"] and orig_last_commit["status"]["contexts"]:
+ for status in orig_last_commit["status"]["contexts"]:
+ name = status["context"]
+ self.conclusions[name] = JobCheckState(
+ name,
+ status["targetUrl"],
+ status["state"],
+ classification=None,
+ job_id=None,
+ title=None,
+ summary=None,
+ )
+
+ return self.conclusions
+
+ def get_authors(self) -> Dict[str, str]:
+ rc = {}
+ for idx in range(len(self._fetch_authors())):
+ rc[self.get_committer_login(idx)] = self.get_committer_author(idx)
+
+ return rc
+
+ def get_author(self) -> str:
+ authors = self.get_authors()
+ if len(authors) == 1:
+ return next(iter(authors.values()))
+ creator = self.get_pr_creator_login()
+ # If PR creator is not among authors
+ # Assume it was authored by first commit author
+ if creator not in authors:
+ return self.get_committer_author(0)
+ return authors[creator]
+
+ def get_title(self) -> str:
+ return cast(str, self.info["title"])
+
+ def get_body(self) -> str:
+ return cast(str, self.info["body"])
+
+ def get_merge_commit(self) -> Optional[str]:
+ mc = self.info["mergeCommit"]
+ return mc["oid"] if mc is not None else None
+
+ def get_pr_url(self) -> str:
+ return f"https://github.com/{self.org}/{self.project}/pull/{self.pr_num}"
+
+ @staticmethod
+ def _comment_from_node(node: Any) -> GitHubComment:
+ editor = node["editor"]
+ return GitHubComment(
+ body_text=node["bodyText"],
+ created_at=node["createdAt"] if "createdAt" in node else "",
+ author_login=node["author"]["login"],
+ author_association=node["authorAssociation"],
+ editor_login=editor["login"] if editor else None,
+ database_id=node["databaseId"],
+ url=node["url"],
+ )
+
+ def get_comments(self) -> List[GitHubComment]:
+ if self.comments is not None:
+ return self.comments
+ self.comments = []
+ info = self.info["comments"]
+ # Do not try to fetch more than 10K comments
+ for _ in range(100):
+ self.comments = [
+ self._comment_from_node(node) for node in info["nodes"]
+ ] + self.comments
+ if not info["pageInfo"]["hasPreviousPage"]:
+ break
+ rc = gh_graphql(
+ GH_GET_PR_PREV_COMMENTS,
+ name=self.project,
+ owner=self.org,
+ number=self.pr_num,
+ cursor=info["pageInfo"]["startCursor"],
+ )
+ info = rc["data"]["repository"]["pullRequest"]["comments"]
+ return self.comments
+
+ def get_last_comment(self) -> GitHubComment:
+ return self._comment_from_node(self.info["comments"]["nodes"][-1])
+
+ def get_comment_by_id(self, database_id: int) -> GitHubComment:
+ if self.comments is None:
+ # Fastpath - try searching in partial prefetched comments
+ for node in self.info["comments"]["nodes"]:
+ comment = self._comment_from_node(node)
+ if comment.database_id == database_id:
+ return comment
+
+ for comment in self.get_comments():
+ if comment.database_id == database_id:
+ return comment
+
+ # The comment could have actually been a review left on the PR (the message written alongside the review).
+ # (This is generally done to trigger the merge right when a comment is left)
+ # Check those review comments to see if one of those was the comment in question.
+ for node in self.info["reviews"]["nodes"]:
+ # These review comments contain all the fields regular comments need
+ comment = self._comment_from_node(node)
+ if comment.database_id == database_id:
+ return comment
+
+ raise RuntimeError(f"Comment with id {database_id} not found")
+
+ def get_diff_revision(self) -> Optional[str]:
+ rc = RE_DIFF_REV.search(self.get_body())
+ return rc.group(1) if rc is not None else None
+
+ def has_internal_changes(self) -> bool:
+ checkrun_name = INTERNAL_CHANGES_CHECKRUN_NAME
+ if self.get_diff_revision() is None:
+ return False
+ checks = self.get_checkrun_conclusions()
+ if checks is None or checkrun_name not in checks:
+ return False
+ return checks[checkrun_name].status != "SUCCESS"
+
+ def has_no_connected_diff(self) -> bool:
+ checkrun_name = INTERNAL_CHANGES_CHECKRUN_NAME
+ checks = self.get_checkrun_conclusions()
+ if checks is None or checkrun_name not in checks:
+ return False
+ return checks[checkrun_name].title == HAS_NO_CONNECTED_DIFF_TITLE
+
+ def merge_ghstack_into(
+ self,
+ repo: GitRepo,
+ skip_mandatory_checks: bool,
+ comment_id: Optional[int] = None,
+ skip_all_rule_checks: bool = False,
+ ) -> List["GitHubPR"]:
+ assert self.is_ghstack_pr()
+ ghstack_prs = get_ghstack_prs(
+ repo, self, open_only=False
+ ) # raises error if out of sync
+ pr_dependencies = []
+ for pr, rev in ghstack_prs:
+ if pr.is_closed():
+ pr_dependencies.append(pr)
+ continue
+
+ commit_msg = pr.gen_commit_message(
+ filter_ghstack=True, ghstack_deps=pr_dependencies
+ )
+ if pr.pr_num != self.pr_num and not skip_all_rule_checks:
+ # Raises exception if matching rule is not found
+ find_matching_merge_rule(
+ pr,
+ repo,
+ skip_mandatory_checks=skip_mandatory_checks,
+ skip_internal_checks=can_skip_internal_checks(self, comment_id),
+ )
+ repo.cherry_pick(rev)
+ repo.amend_commit_message(commit_msg)
+ pr_dependencies.append(pr)
+ return [x for x, _ in ghstack_prs if not x.is_closed()]
+
+ def gen_commit_message(
+ self,
+ filter_ghstack: bool = False,
+ ghstack_deps: Optional[List["GitHubPR"]] = None,
+ ) -> str:
+ """Fetches title and body from PR description
+ adds reviewed by, pull request resolved and optionally
+ filters out ghstack info"""
+ # Adding the url here makes it clickable within the Github UI
+ approved_by_urls = ", ".join(
+ prefix_with_github_url(login) for login in self.get_approved_by()
+ )
+ # Remove "cc: " line from the message body
+ msg_body = re.sub(RE_PR_CC_LINE, "", self.get_body())
+ if filter_ghstack:
+ msg_body = re.sub(RE_GHSTACK_DESC, "", msg_body)
+ msg = self.get_title() + f" (#{self.pr_num})\n\n"
+ msg += msg_body
+
+ # Mention PR co-authors
+ for author_login, author_name in self.get_authors().items():
+ if author_login != self.get_pr_creator_login():
+ msg += f"\nCo-authored-by: {author_name}"
+
+ msg += f"\nPull Request resolved: {self.get_pr_url()}\n"
+ msg += f"Approved by: {approved_by_urls}\n"
+ if ghstack_deps:
+ msg += f"ghstack dependencies: {', '.join([f'#{pr.pr_num}' for pr in ghstack_deps])}\n"
+ return msg
+
+ def add_numbered_label(self, label_base: str, dry_run: bool) -> None:
+ labels = self.get_labels() if self.labels is not None else []
+ full_label = label_base
+ count = 0
+ for label in labels:
+ if label_base in label:
+ count += 1
+ full_label = f"{label_base}X{count}"
+ gh_add_labels(self.org, self.project, self.pr_num, [full_label], dry_run)
+
+ def merge_into(
+ self,
+ repo: GitRepo,
+ *,
+ skip_mandatory_checks: bool = False,
+ dry_run: bool = False,
+ comment_id: Optional[int] = None,
+ ignore_current_checks: Optional[List[str]] = None,
+ ) -> None:
+ # Raises exception if matching rule is not found
+ (
+ merge_rule,
+ pending_checks,
+ failed_checks,
+ ignorable_checks,
+ ) = find_matching_merge_rule(
+ self,
+ repo,
+ skip_mandatory_checks=skip_mandatory_checks,
+ skip_internal_checks=can_skip_internal_checks(self, comment_id),
+ ignore_current_checks=ignore_current_checks,
+ )
+ additional_merged_prs = self.merge_changes(
+ repo, skip_mandatory_checks, comment_id
+ )
+
+ repo.push(self.default_branch(), dry_run)
+ if not dry_run:
+ self.add_numbered_label(MERGE_COMPLETE_LABEL, dry_run)
+ for pr in additional_merged_prs:
+ pr.add_numbered_label(MERGE_COMPLETE_LABEL, dry_run)
+
+ if comment_id and self.pr_num:
+ # When the merge process reaches this part, we can assume that the commit
+ # has been successfully pushed to trunk
+ merge_commit_sha = repo.rev_parse(name=REMOTE_MAIN_BRANCH)
+
+ # Finally, upload the record to Rockset. The list of pending and failed
+ # checks are at the time of the merge
+ save_merge_record(
+ collection=ROCKSET_MERGES_COLLECTION,
+ comment_id=comment_id,
+ pr_num=self.pr_num,
+ owner=self.org,
+ project=self.project,
+ author=self.get_author(),
+ pending_checks=pending_checks,
+ failed_checks=failed_checks,
+ ignore_current_checks=ignorable_checks.get("IGNORE_CURRENT_CHECK", []),
+ broken_trunk_checks=ignorable_checks.get("BROKEN_TRUNK", []),
+ flaky_checks=ignorable_checks.get("FLAKY", []),
+ unstable_checks=ignorable_checks.get("UNSTABLE", []),
+ last_commit_sha=self.last_commit().get("oid", ""),
+ merge_base_sha=self.get_merge_base(),
+ merge_commit_sha=merge_commit_sha,
+ is_failed=False,
+ dry_run=dry_run,
+ skip_mandatory_checks=skip_mandatory_checks,
+ ignore_current=bool(ignore_current_checks),
+ workspace=ROCKSET_MERGES_WORKSPACE,
+ )
+ else:
+ print("Missing comment ID or PR number, couldn't upload to Rockset")
+
+ def merge_changes(
+ self,
+ repo: GitRepo,
+ skip_mandatory_checks: bool = False,
+ comment_id: Optional[int] = None,
+ branch: Optional[str] = None,
+ skip_all_rule_checks: bool = False,
+ ) -> List["GitHubPR"]:
+ """
+ :param skip_all_rule_checks: If true, skips all rule checks, useful for dry-running merge locally
+ """
+ branch_to_merge_into = self.default_branch() if branch is None else branch
+ if repo.current_branch() != branch_to_merge_into:
+ repo.checkout(branch_to_merge_into)
+ if not self.is_ghstack_pr():
+ msg = self.gen_commit_message()
+ pr_branch_name = f"__pull-request-{self.pr_num}__init__"
+ repo.fetch(f"pull/{self.pr_num}/head", pr_branch_name)
+ repo._run_git("merge", "--squash", pr_branch_name)
+ repo._run_git("commit", f'--author="{self.get_author()}"', "-m", msg)
+ return []
+ else:
+ return self.merge_ghstack_into(
+ repo,
+ skip_mandatory_checks,
+ comment_id=comment_id,
+ skip_all_rule_checks=skip_all_rule_checks,
+ )
+
+
+class MergeRuleFailedError(RuntimeError):
+ def __init__(self, message: str, rule: Optional["MergeRule"] = None) -> None:
+ super().__init__(message)
+ self.rule = rule
+
+
+class MandatoryChecksMissingError(MergeRuleFailedError):
+ pass
+
+
+class PostCommentError(Exception):
+ pass
+
+
+@dataclass
+class MergeRule:
+ name: str
+ patterns: List[str]
+ approved_by: List[str]
+ mandatory_checks_name: Optional[List[str]]
+ ignore_flaky_failures: bool = True
+
+
+def gen_new_issue_link(
+ org: str, project: str, labels: List[str], template: str = "bug-report.yml"
+) -> str:
+ labels_str = ",".join(labels)
+ return (
+ f"https://github.com/{org}/{project}/issues/new?"
+ f"labels={urllib.parse.quote(labels_str)}&"
+ f"template={urllib.parse.quote(template)}"
+ )
+
+
+def read_merge_rules(
+ repo: Optional[GitRepo], org: str, project: str
+) -> List[MergeRule]:
+ """Returns the list of all merge rules for the repo or project.
+
+ NB: this function is used in Meta-internal workflows, see the comment
+ at the top of this file for details.
+ """
+ repo_relative_rules_path = MERGE_RULE_PATH
+ if repo is None:
+ json_data = gh_fetch_url(
+ f"https://api.github.com/repos/{org}/{project}/contents/{repo_relative_rules_path}",
+ headers={"Accept": "application/vnd.github.v3+json"},
+ reader=json.load,
+ )
+ content = base64.b64decode(json_data["content"])
+ return [MergeRule(**x) for x in yaml.safe_load(content)]
+ else:
+ rules_path = Path(repo.repo_dir) / repo_relative_rules_path
+ if not rules_path.exists():
+ print(f"{rules_path} does not exist, returning empty rules")
+ return []
+ with open(rules_path) as fp:
+ rc = yaml.safe_load(fp)
+ return [MergeRule(**x) for x in rc]
+
+
+def find_matching_merge_rule(
+ pr: GitHubPR,
+ repo: Optional[GitRepo] = None,
+ skip_mandatory_checks: bool = False,
+ skip_internal_checks: bool = False,
+ ignore_current_checks: Optional[List[str]] = None,
+) -> Tuple[
+ MergeRule,
+ List[Tuple[str, Optional[str], Optional[int]]],
+ List[Tuple[str, Optional[str], Optional[int]]],
+ Dict[str, List[Any]],
+]:
+ """
+ Returns merge rule matching to this pr together with the list of associated pending
+ and failing jobs OR raises an exception.
+
+ NB: this function is used in Meta-internal workflows, see the comment at the top of
+ this file for details.
+ """
+ changed_files = pr.get_changed_files()
+ approved_by = set(pr.get_approved_by())
+
+ issue_link = gen_new_issue_link(
+ org=pr.org,
+ project=pr.project,
+ labels=["module: ci"],
+ )
+ reject_reason = f"No rule found to match PR. Please [report]{issue_link} this issue to DevX team."
+
+ rules = read_merge_rules(repo, pr.org, pr.project)
+ if not rules:
+ reject_reason = f"Rejecting the merge as no rules are defined for the repository in {MERGE_RULE_PATH}"
+ raise RuntimeError(reject_reason)
+
+ checks = pr.get_checkrun_conclusions()
+ checks = get_classifications(
+ pr.pr_num,
+ pr.project,
+ checks,
+ ignore_current_checks=ignore_current_checks,
+ )
+
+ # This keeps the list of all approvers that could stamp the change
+ all_rule_approvers = {}
+
+ # PRs can fail multiple merge rules, but it only needs to pass one rule to be approved.
+ # If it fails all rules, we need to find the rule that it came closest to passing and report
+ # that to the dev.
+ #
+ # reject_reason_score ranks rules by relevancy. The higher the score, the more relevant the
+ # rule & rejection reason, and we only care about the most relevant rule/reason
+ #
+ # reject_reason_score intrepretation:
+ # Score 0 to 10K - how many files rule matched
+ # Score 10K - matched all files, but no overlapping approvers
+ # Score 20K - matched all files and approvers, but mandatory checks are pending
+ # Score 30k - Matched all files and approvers, but mandatory checks failed
+ reject_reason_score = 0
+ for rule in rules:
+ rule_name = rule.name
+ patterns_re = patterns_to_regex(rule.patterns)
+ non_matching_files = []
+
+ # Does this rule apply to all the files?
+ for fname in changed_files:
+ if not patterns_re.match(fname):
+ non_matching_files.append(fname)
+ if len(non_matching_files) > 0:
+ num_matching_files = len(changed_files) - len(non_matching_files)
+ if num_matching_files > reject_reason_score:
+ reject_reason_score = num_matching_files
+ reject_reason = "\n".join(
+ (
+ f"Not all files match rule `{rule_name}`.",
+ f"{num_matching_files} files matched, but there are still non-matching files:",
+ f"{','.join(non_matching_files[:5])}{', ...' if len(non_matching_files) > 5 else ''}",
+ )
+ )
+ continue
+
+ # If rule needs approvers but PR has not been reviewed, skip it
+ if len(rule.approved_by) > 0 and len(approved_by) == 0:
+ if reject_reason_score < 10000:
+ reject_reason_score = 10000
+ reject_reason = f"PR #{pr.pr_num} has not been reviewed yet"
+ continue
+
+ # Does the PR have the required approvals for this rule?
+ rule_approvers = set()
+ for approver in rule.approved_by:
+ if "/" in approver:
+ org, name = approver.split("/")
+ rule_approvers.update(gh_get_team_members(org, name))
+ else:
+ rule_approvers.add(approver)
+ approvers_intersection = approved_by.intersection(rule_approvers)
+ # If rule requires approvers but they aren't the ones that reviewed PR
+ if len(approvers_intersection) == 0 and len(rule_approvers) > 0:
+ # Less than or equal is intentionally used here to gather all potential
+ # approvers
+ if reject_reason_score <= 10000:
+ reject_reason_score = 10000
+
+ all_rule_approvers[rule.name] = rule.approved_by
+ # Prepare the reject reason
+ all_rule_approvers_msg = [
+ f"- {name} ({', '.join(approved_by[:5])}{', ...' if len(approved_by) > 5 else ''})"
+ for name, approved_by in all_rule_approvers.items()
+ ]
+
+ reject_reason = "Approvers from one of the following sets are needed:\n"
+ reject_reason += "\n".join(all_rule_approvers_msg)
+
+ continue
+
+ # Does the PR pass the checks required by this rule?
+ mandatory_checks = (
+ rule.mandatory_checks_name if rule.mandatory_checks_name is not None else []
+ )
+ required_checks = list(
+ filter(
+ lambda x: ("EasyCLA" in x)
+ or ("Facebook CLA Check" in x)
+ or not skip_mandatory_checks,
+ mandatory_checks,
+ )
+ )
+ pending_checks, failed_checks, _ = categorize_checks(
+ checks,
+ required_checks,
+ ok_failed_checks_threshold=(
+ IGNORABLE_FAILED_CHECKS_THESHOLD if rule.ignore_flaky_failures else 0
+ ),
+ )
+
+ # categorize_checks assumes all tests are required if required_checks is empty.
+ # this is a workaround as we want to keep that behavior for categorize_checks
+ # generally.
+ if not required_checks:
+ pending_checks = []
+ failed_checks = []
+
+ hud_link = f"https://hud.pytorch.org/{pr.org}/{pr.project}/commit/{pr.last_commit()['oid']}"
+ if len(failed_checks) > 0:
+ if reject_reason_score < 30000:
+ reject_reason_score = 30000
+ reject_reason = "\n".join(
+ (
+ f"{len(failed_checks)} mandatory check(s) failed. The first few are:",
+ *checks_to_markdown_bullets(failed_checks),
+ "",
+ f"Dig deeper by [viewing the failures on hud]({hud_link})",
+ )
+ )
+ continue
+ elif len(pending_checks) > 0:
+ if reject_reason_score < 20000:
+ reject_reason_score = 20000
+ reject_reason = "\n".join(
+ (
+ f"{len(pending_checks)} mandatory check(s) are pending/not yet run. The first few are:",
+ *checks_to_markdown_bullets(pending_checks),
+ "",
+ f"Dig deeper by [viewing the pending checks on hud]({hud_link})",
+ )
+ )
+ continue
+
+ if not skip_internal_checks and pr.has_internal_changes():
+ raise RuntimeError(
+ "This PR has internal changes and must be landed via Phabricator"
+ )
+
+ # Categorize all checks when skip_mandatory_checks (force merge) is set. Do it here
+ # where the list of checks is readily available. These records will be saved into
+ # Rockset merge records
+ (
+ pending_mandatory_checks,
+ failed_mandatory_checks,
+ ignorable_checks,
+ ) = categorize_checks(
+ checks,
+ [],
+ ok_failed_checks_threshold=IGNORABLE_FAILED_CHECKS_THESHOLD,
+ )
+ return (
+ rule,
+ pending_mandatory_checks,
+ failed_mandatory_checks,
+ ignorable_checks,
+ )
+
+ if reject_reason_score == 20000:
+ raise MandatoryChecksMissingError(reject_reason, rule)
+ raise MergeRuleFailedError(reject_reason, rule)
+
+
+def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str:
+ return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks)
+
+
+def checks_to_markdown_bullets(
+ checks: List[Tuple[str, Optional[str], Optional[int]]]
+) -> List[str]:
+ return [
+ f"- [{c[0]}]({c[1]})" if c[1] is not None else f"- {c[0]}" for c in checks[:5]
+ ]
+
+
+@retries_decorator()
+def save_merge_record(
+ collection: str,
+ comment_id: int,
+ pr_num: int,
+ owner: str,
+ project: str,
+ author: str,
+ pending_checks: List[Tuple[str, Optional[str], Optional[int]]],
+ failed_checks: List[Tuple[str, Optional[str], Optional[int]]],
+ ignore_current_checks: List[Tuple[str, Optional[str], Optional[int]]],
+ broken_trunk_checks: List[Tuple[str, Optional[str], Optional[int]]],
+ flaky_checks: List[Tuple[str, Optional[str], Optional[int]]],
+ unstable_checks: List[Tuple[str, Optional[str], Optional[int]]],
+ last_commit_sha: str,
+ merge_base_sha: str,
+ merge_commit_sha: str = "",
+ is_failed: bool = False,
+ dry_run: bool = False,
+ skip_mandatory_checks: bool = False,
+ ignore_current: bool = False,
+ error: str = "",
+ workspace: str = "commons",
+) -> None:
+ """
+ This saves the merge records into Rockset, so we can query them (for fun and profit)
+ """
+ if dry_run:
+ # Decide not to save the record to Rockset if dry-run is set to not pollute
+ # the collection
+ return
+
+ try:
+ import rockset # type: ignore[import]
+
+ # Prepare the record to be written into Rockset
+ data = [
+ {
+ "comment_id": comment_id,
+ "pr_num": pr_num,
+ "owner": owner,
+ "project": project,
+ "author": author,
+ "pending_checks": pending_checks,
+ "failed_checks": failed_checks,
+ "ignore_current_checks": ignore_current_checks,
+ "broken_trunk_checks": broken_trunk_checks,
+ "flaky_checks": flaky_checks,
+ "unstable_checks": unstable_checks,
+ "last_commit_sha": last_commit_sha,
+ "merge_base_sha": merge_base_sha,
+ "merge_commit_sha": merge_commit_sha,
+ "is_failed": is_failed,
+ "skip_mandatory_checks": skip_mandatory_checks,
+ "ignore_current": ignore_current,
+ "error": error,
+ }
+ ]
+
+ client = rockset.RocksetClient(
+ host="api.usw2a1.rockset.com", api_key=os.environ["ROCKSET_API_KEY"]
+ )
+ client.Documents.add_documents(
+ collection=collection,
+ data=data,
+ workspace=workspace,
+ )
+
+ except ModuleNotFoundError:
+ print("Rockset is missing, no record will be saved")
+ return
+
+
+@retries_decorator(rc=[])
+def get_rockset_results(head_sha: str, merge_base: str) -> List[Dict[str, Any]]:
+ query = f"""
+SELECT
+ w.name as workflow_name,
+ j.id,
+ j.name,
+ j.conclusion,
+ j.completed_at,
+ j.html_url,
+ j.head_sha,
+ j.torchci_classification.captures as failure_captures,
+ LENGTH(j.steps) as steps,
+FROM
+ commons.workflow_job j join commons.workflow_run w on w.id = j.run_id
+where
+ j.head_sha in ('{head_sha}','{merge_base}')
+"""
+ try:
+ import rockset # type: ignore[import]
+
+ res = rockset.RocksetClient(
+ host="api.usw2a1.rockset.com", api_key=os.environ["ROCKSET_API_KEY"]
+ ).sql(query)
+ return cast(List[Dict[str, Any]], res.results)
+ except ModuleNotFoundError:
+ print("Could not use RockSet as rocket dependency is missing")
+ return []
+
+
+@retries_decorator()
+def get_drci_classifications(pr_num: int, project: str = "pytorch") -> Any:
+ """
+ Query HUD API to find similar failures to decide if they are flaky
+ """
+ # NB: This doesn't work internally atm because this requires making an
+ # external API call to HUD
+ failures = gh_fetch_url(
+ f"https://hud.pytorch.org/api/drci/drci?prNumber={pr_num}",
+ data=f"repo={project}",
+ headers={
+ "Authorization": os.getenv("DRCI_BOT_KEY", ""),
+ "Accept": "application/vnd.github.v3+json",
+ },
+ method="POST",
+ reader=json.load,
+ )
+
+ return failures.get(str(pr_num), {}) if failures else {}
+
+
+REMOVE_JOB_NAME_SUFFIX_REGEX = re.compile(r", [0-9]+, [0-9]+, .+\)$")
+
+
+def remove_job_name_suffix(name: str, replacement: str = ")") -> str:
+ return re.sub(REMOVE_JOB_NAME_SUFFIX_REGEX, replacement, name)
+
+
+def is_broken_trunk(
+ name: str,
+ drci_classifications: Any,
+) -> bool:
+ if not name or not drci_classifications:
+ return False
+
+ # Consult the list of broken trunk failures from Dr.CI
+ return any(
+ name == broken_trunk["name"]
+ for broken_trunk in drci_classifications.get("BROKEN_TRUNK", [])
+ )
+
+
+def is_flaky(
+ name: str,
+ drci_classifications: Any,
+) -> bool:
+ if not name or not drci_classifications:
+ return False
+
+ # Consult the list of flaky failures from Dr.CI
+ return any(name == flaky["name"] for flaky in drci_classifications.get("FLAKY", []))
+
+
+def is_invalid_cancel(
+ name: str,
+ conclusion: Optional[str],
+ drci_classifications: Any,
+) -> bool:
+ """
+ After https://github.com/pytorch/test-infra/pull/4579, invalid cancelled
+ signals have been removed from HUD and Dr.CI. The same needs to be done
+ here for consistency
+ """
+ if (
+ not name
+ or not drci_classifications
+ or not conclusion
+ or conclusion.upper() != "CANCELLED"
+ ):
+ return False
+
+ # If a job is cancelled and not listed as a failure by Dr.CI, it's an
+ # invalid signal and can be ignored
+ return all(
+ name != failure["name"] for failure in drci_classifications.get("FAILED", [])
+ )
+
+
+def get_classifications(
+ pr_num: int,
+ project: str,
+ checks: Dict[str, JobCheckState],
+ ignore_current_checks: Optional[List[str]],
+) -> Dict[str, JobCheckState]:
+ # Get the failure classification from Dr.CI, which is the source of truth
+ # going forward. It's preferable to try calling Dr.CI API directly first
+ # to get the latest results as well as update Dr.CI PR comment
+ drci_classifications = get_drci_classifications(pr_num=pr_num, project=project)
+
+ def get_readable_drci_results(drci_classifications: Any) -> str:
+ try:
+ s = f"From Dr.CI API ({pr_num}):\n"
+ for classification, jobs in drci_classifications.items():
+ s += f" {classification}: \n"
+ for job in jobs:
+ s += f" {job['id']} {job['name']}\n"
+ return s
+ except Exception:
+ return f"From Dr.CI API: {json.dumps(drci_classifications)}"
+
+ print(get_readable_drci_results(drci_classifications))
+
+ # NB: if the latest results from Dr.CI is not available, i.e. when calling from
+ # SandCastle, we fallback to any results we can find on Dr.CI check run summary
+ if (
+ not drci_classifications
+ and DRCI_CHECKRUN_NAME in checks
+ and checks[DRCI_CHECKRUN_NAME]
+ and checks[DRCI_CHECKRUN_NAME].summary
+ ):
+ drci_summary = checks[DRCI_CHECKRUN_NAME].summary
+ try:
+ print(f"From Dr.CI checkrun summary: {drci_summary}")
+ drci_classifications = json.loads(str(drci_summary))
+ except json.JSONDecodeError as error:
+ warn("Invalid Dr.CI checkrun summary")
+ drci_classifications = {}
+
+ checks_with_classifications = checks.copy()
+ for name, check in checks.items():
+ if check.status == "SUCCESS" or check.status == "NEUTRAL":
+ continue
+
+ if "unstable" in name:
+ checks_with_classifications[name] = JobCheckState(
+ check.name,
+ check.url,
+ check.status,
+ "UNSTABLE",
+ check.job_id,
+ check.title,
+ check.summary,
+ )
+ continue
+
+ # NB: It's important to note that when it comes to ghstack and broken trunk classification,
+ # Dr.CI uses the base of the whole stack
+ if is_broken_trunk(name, drci_classifications):
+ checks_with_classifications[name] = JobCheckState(
+ check.name,
+ check.url,
+ check.status,
+ "BROKEN_TRUNK",
+ check.job_id,
+ check.title,
+ check.summary,
+ )
+ continue
+
+ elif is_flaky(name, drci_classifications):
+ checks_with_classifications[name] = JobCheckState(
+ check.name,
+ check.url,
+ check.status,
+ "FLAKY",
+ check.job_id,
+ check.title,
+ check.summary,
+ )
+ continue
+
+ elif is_invalid_cancel(name, check.status, drci_classifications):
+ # NB: Create a new category here for invalid cancelled signals because
+ # there are usually many of them when they happen. So, they shouldn't
+ # be counted toward ignorable failures threshold
+ checks_with_classifications[name] = JobCheckState(
+ check.name,
+ check.url,
+ check.status,
+ "INVALID_CANCEL",
+ check.job_id,
+ check.title,
+ check.summary,
+ )
+ continue
+
+ if ignore_current_checks is not None and name in ignore_current_checks:
+ checks_with_classifications[name] = JobCheckState(
+ check.name,
+ check.url,
+ check.status,
+ "IGNORE_CURRENT_CHECK",
+ check.job_id,
+ check.title,
+ check.summary,
+ )
+
+ return checks_with_classifications
+
+
+def filter_checks_with_lambda(
+ checks: JobNameToStateDict, status_filter: Callable[[Optional[str]], bool]
+) -> List[JobCheckState]:
+ return [check for check in checks.values() if status_filter(check.status)]
+
+
+def get_pr_commit_sha(repo: GitRepo, pr: GitHubPR) -> str:
+ commit_sha = pr.get_merge_commit()
+ if commit_sha is not None:
+ return commit_sha
+ commits = repo.commits_resolving_gh_pr(pr.pr_num)
+ if len(commits) == 0:
+ raise PostCommentError("Can't find any commits resolving PR")
+ return commits[0]
+
+
+def validate_revert(
+ repo: GitRepo, pr: GitHubPR, *, comment_id: Optional[int] = None
+) -> Tuple[str, str]:
+ comment = (
+ pr.get_last_comment()
+ if comment_id is None
+ else pr.get_comment_by_id(comment_id)
+ )
+ if comment.editor_login is not None:
+ raise PostCommentError("Don't want to revert based on edited command")
+ author_association = comment.author_association
+ author_login = comment.author_login
+ allowed_reverters = ["COLLABORATOR", "MEMBER", "OWNER"]
+ # For some reason, one can not be a member of private repo, only CONTRIBUTOR
+ if pr.is_base_repo_private():
+ allowed_reverters.append("CONTRIBUTOR")
+ if author_association not in allowed_reverters:
+ raise PostCommentError(
+ f"Will not revert as @{author_login} is not one of "
+ f"[{', '.join(allowed_reverters)}], but instead is {author_association}."
+ )
+
+ # Raises exception if matching rule is not found, but ignores all status checks
+ find_matching_merge_rule(
+ pr, repo, skip_mandatory_checks=True, skip_internal_checks=True
+ )
+ commit_sha = get_pr_commit_sha(repo, pr)
+ return (author_login, commit_sha)
+
+
+def get_ghstack_dependent_prs(
+ repo: GitRepo, pr: GitHubPR, only_closed: bool = True
+) -> List[Tuple[str, GitHubPR]]:
+ """
+ Get the PRs in the stack that are above this PR (inclusive).
+ Throws error if stack have branched or original branches are gone
+ """
+ assert pr.is_ghstack_pr()
+ orig_ref = f"{repo.remote}/{pr.get_ghstack_orig_ref()}"
+ rev_list = repo.revlist(f"{pr.default_branch()}..{orig_ref}")
+ if len(rev_list) == 0:
+ raise RuntimeError(
+ f"PR {pr.pr_num} does not have any revisions associated with it"
+ )
+ skip_len = len(rev_list) - 1
+ for branch in repo.branches_containing_ref(orig_ref):
+ candidate = repo.revlist(f"{pr.default_branch()}..{branch}")
+ # Pick longest candidate
+ if len(candidate) > len(rev_list):
+ candidate, rev_list = rev_list, candidate
+ # Validate that candidate always ends rev-list
+ if rev_list[-len(candidate) :] != candidate:
+ raise RuntimeError(
+ f"Branch {branch} revlist {', '.join(candidate)} is not a subset of {', '.join(rev_list)}"
+ )
+ # Remove commits original PR depends on
+ if skip_len > 0:
+ rev_list = rev_list[:-skip_len]
+ rc: List[Tuple[str, GitHubPR]] = []
+ for pr_, sha in _revlist_to_prs(repo, pr, rev_list):
+ if not pr_.is_closed():
+ if not only_closed:
+ rc.append(("", pr_))
+ continue
+ commit_sha = get_pr_commit_sha(repo, pr_)
+ rc.append((commit_sha, pr_))
+ return rc
+
+
+def do_revert_prs(
+ repo: GitRepo,
+ shas_and_prs: List[Tuple[str, GitHubPR]],
+ *,
+ author_login: str,
+ extra_msg: str = "",
+ skip_internal_checks: bool = False,
+ dry_run: bool = False,
+) -> None:
+ # Prepare and push revert commits
+ commit_shas: List[str] = []
+ for commit_sha, pr in shas_and_prs:
+ revert_msg = f"\nReverted {pr.get_pr_url()} on behalf of {prefix_with_github_url(author_login)}"
+ revert_msg += extra_msg
+ repo.checkout(pr.default_branch())
+ repo.revert(commit_sha)
+ msg = repo.commit_message("HEAD")
+ msg = re.sub(RE_PULL_REQUEST_RESOLVED, "", msg)
+ msg += revert_msg
+ repo.amend_commit_message(msg)
+ repo.push(shas_and_prs[0][1].default_branch(), dry_run)
+
+ # Comment/reopen PRs
+ for commit_sha, pr in shas_and_prs:
+ revert_message = (
+ f"@{pr.get_pr_creator_login()} your PR has been successfully reverted."
+ )
+ if (
+ pr.has_internal_changes()
+ and not pr.has_no_connected_diff()
+ and not skip_internal_checks
+ ):
+ revert_message += "\n:warning: This PR might contain internal changes"
+ revert_message += "\ncc: @pytorch/pytorch-dev-infra"
+ gh_post_pr_comment(
+ pr.org, pr.project, pr.pr_num, revert_message, dry_run=dry_run
+ )
+
+ pr.add_numbered_label("reverted", dry_run)
+ if not dry_run:
+ gh_post_commit_comment(pr.org, pr.project, commit_sha, revert_msg)
+ gh_update_pr_state(pr.org, pr.project, pr.pr_num)
+
+
+def try_revert(
+ repo: GitRepo,
+ pr: GitHubPR,
+ *,
+ dry_run: bool = False,
+ comment_id: Optional[int] = None,
+ reason: Optional[str] = None,
+) -> None:
+ try:
+ author_login, commit_sha = validate_revert(repo, pr, comment_id=comment_id)
+ except PostCommentError as e:
+ gh_post_pr_comment(pr.org, pr.project, pr.pr_num, str(e), dry_run=dry_run)
+ return
+
+ extra_msg = f" due to {reason}" if reason is not None else ""
+ extra_msg += (
+ f" ([comment]({pr.get_comment_by_id(comment_id).url}))\n"
+ if comment_id is not None
+ else "\n"
+ )
+ shas_and_prs = [(commit_sha, pr)]
+ if pr.is_ghstack_pr():
+ try:
+ shas_and_prs = get_ghstack_dependent_prs(repo, pr)
+ prs_to_revert = " ".join([t[1].get_pr_url() for t in shas_and_prs])
+ print(f"About to stack of PRs: {prs_to_revert}")
+ except Exception as e:
+ print(
+ f"Failed to fetch dependent PRs: {str(e)}, fall over to single revert"
+ )
+
+ do_revert_prs(
+ repo,
+ shas_and_prs,
+ author_login=author_login,
+ extra_msg=extra_msg,
+ dry_run=dry_run,
+ skip_internal_checks=can_skip_internal_checks(pr, comment_id),
+ )
+
+
+def prefix_with_github_url(suffix_str: str) -> str:
+ return f"https://github.com/{suffix_str}"
+
+
+def check_for_sev(org: str, project: str, skip_mandatory_checks: bool) -> None:
+ if skip_mandatory_checks:
+ return
+ response = cast(
+ Dict[str, Any],
+ gh_fetch_json_list(
+ "https://api.github.com/search/issues",
+ params={"q": f'repo:{org}/{project} is:open is:issue label:"ci: sev"'},
+ ),
+ )
+ if response["total_count"] != 0:
+ for item in response["items"]:
+ if "MERGE BLOCKING" in item["body"]:
+ raise RuntimeError(
+ "Not merging any PRs at the moment because there is a "
+ + "merge blocking https://github.com/pytorch/pytorch/labels/ci:%20sev issue open at: \n"
+ + f"{item['html_url']}"
+ )
+ return
+
+
+def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
+ return len(list(filter(pattern.match, labels))) > 0
+
+
+def categorize_checks(
+ check_runs: JobNameToStateDict,
+ required_checks: List[str],
+ ok_failed_checks_threshold: Optional[int] = None,
+) -> Tuple[
+ List[Tuple[str, Optional[str], Optional[int]]],
+ List[Tuple[str, Optional[str], Optional[int]]],
+ Dict[str, List[Any]],
+]:
+ """
+ Categories all jobs into the list of pending and failing jobs. All known flaky
+ failures and broken trunk are ignored by defaults when ok_failed_checks_threshold
+ is not set (unlimited)
+ """
+ pending_checks: List[Tuple[str, Optional[str], Optional[int]]] = []
+ failed_checks: List[Tuple[str, Optional[str], Optional[int]]] = []
+
+ # ok_failed_checks is used with ok_failed_checks_threshold while ignorable_failed_checks
+ # is used to keep track of all ignorable failures when saving the merge record on Rockset
+ ok_failed_checks: List[Tuple[str, Optional[str], Optional[int]]] = []
+ ignorable_failed_checks: Dict[str, List[Any]] = defaultdict(list)
+
+ # If required_checks is not set or empty, consider all names are relevant
+ relevant_checknames = [
+ name
+ for name in check_runs.keys()
+ if not required_checks or any(x in name for x in required_checks)
+ ]
+
+ for checkname in required_checks:
+ if all(checkname not in x for x in check_runs.keys()):
+ pending_checks.append((checkname, None, None))
+
+ for checkname in relevant_checknames:
+ status = check_runs[checkname].status
+ url = check_runs[checkname].url
+ classification = check_runs[checkname].classification
+ job_id = check_runs[checkname].job_id
+
+ if status is None and classification != "UNSTABLE":
+ # NB: No need to wait if the job classification is unstable as it would be
+ # ignored anyway. This is useful to not need to wait for scarce resources
+ # like ROCm, which is also frequently in unstable mode
+ pending_checks.append((checkname, url, job_id))
+ elif classification == "INVALID_CANCEL":
+ continue
+ elif not is_passing_status(check_runs[checkname].status):
+ target = (
+ ignorable_failed_checks[classification]
+ if classification
+ in ("IGNORE_CURRENT_CHECK", "BROKEN_TRUNK", "FLAKY", "UNSTABLE")
+ else failed_checks
+ )
+ target.append((checkname, url, job_id))
+
+ if classification in ("BROKEN_TRUNK", "FLAKY", "UNSTABLE"):
+ ok_failed_checks.append((checkname, url, job_id))
+
+ if ok_failed_checks:
+ warn(
+ f"The following {len(ok_failed_checks)} checks failed but were likely due flakiness or broken trunk: "
+ + ", ".join([x[0] for x in ok_failed_checks])
+ + (
+ f" but this is greater than the threshold of {ok_failed_checks_threshold} so merge will fail"
+ if ok_failed_checks_threshold is not None
+ and len(ok_failed_checks) > ok_failed_checks_threshold
+ else ""
+ )
+ )
+
+ if (
+ ok_failed_checks_threshold is not None
+ and len(ok_failed_checks) > ok_failed_checks_threshold
+ ):
+ failed_checks = failed_checks + ok_failed_checks
+
+ # The list of ignorable_failed_checks is returned so that it can be saved into the Rockset merge record
+ return (pending_checks, failed_checks, ignorable_failed_checks)
+
+
+def merge(
+ pr: GitHubPR,
+ repo: GitRepo,
+ dry_run: bool = False,
+ skip_mandatory_checks: bool = False,
+ comment_id: Optional[int] = None,
+ timeout_minutes: int = 400,
+ stale_pr_days: int = 3,
+ ignore_current: bool = False,
+) -> None:
+ initial_commit_sha = pr.last_commit()["oid"]
+ pr_link = f"https://github.com/{pr.org}/{pr.project}/pull/{pr.pr_num}"
+ print(f"Attempting merge of {initial_commit_sha} ({pr_link})")
+
+ if MERGE_IN_PROGRESS_LABEL not in pr.get_labels():
+ gh_add_labels(pr.org, pr.project, pr.pr_num, [MERGE_IN_PROGRESS_LABEL], dry_run)
+
+ explainer = TryMergeExplainer(
+ skip_mandatory_checks,
+ pr.get_labels(),
+ pr.pr_num,
+ pr.org,
+ pr.project,
+ ignore_current,
+ )
+
+ # probably a bad name, but this is a list of current checks that should be
+ # ignored and is toggled by the --ignore-current flag
+ ignore_current_checks_info = []
+
+ if pr.is_ghstack_pr():
+ get_ghstack_prs(repo, pr) # raises error if out of sync
+
+ check_for_sev(pr.org, pr.project, skip_mandatory_checks)
+
+ if skip_mandatory_checks:
+ gh_post_pr_comment(
+ pr.org,
+ pr.project,
+ pr.pr_num,
+ explainer.get_merge_message(),
+ dry_run=dry_run,
+ )
+ return pr.merge_into(
+ repo,
+ dry_run=dry_run,
+ skip_mandatory_checks=skip_mandatory_checks,
+ comment_id=comment_id,
+ )
+
+ # Check for approvals
+ find_matching_merge_rule(pr, repo, skip_mandatory_checks=True)
+
+ if not has_required_labels(pr):
+ raise RuntimeError(LABEL_ERR_MSG.lstrip(" #"))
+
+ if ignore_current:
+ checks = pr.get_checkrun_conclusions()
+ _, failing, _ = categorize_checks(
+ checks,
+ list(checks.keys()),
+ ok_failed_checks_threshold=IGNORABLE_FAILED_CHECKS_THESHOLD,
+ )
+ ignore_current_checks_info = failing
+
+ gh_post_pr_comment(
+ pr.org,
+ pr.project,
+ pr.pr_num,
+ explainer.get_merge_message(ignore_current_checks_info),
+ dry_run=dry_run,
+ )
+
+ start_time = time.time()
+ last_exception = ""
+ elapsed_time = 0.0
+ ignore_current_checks = [
+ x[0] for x in ignore_current_checks_info
+ ] # convert to List[str] for convenience
+ while elapsed_time < timeout_minutes * 60:
+ check_for_sev(pr.org, pr.project, skip_mandatory_checks)
+ current_time = time.time()
+ elapsed_time = current_time - start_time
+ print(
+ f"Attempting merge of https://github.com/{pr.org}/{pr.project}/pull/{pr.pr_num} ({elapsed_time / 60} minutes elapsed)"
+ )
+ pr = GitHubPR(pr.org, pr.project, pr.pr_num)
+ if initial_commit_sha != pr.last_commit()["oid"]:
+ raise RuntimeError(
+ "New commits were pushed while merging. Please rerun the merge command."
+ )
+ try:
+ required_checks = []
+ failed_rule_message = None
+ ignore_flaky_failures = True
+ try:
+ find_matching_merge_rule(
+ pr, repo, ignore_current_checks=ignore_current_checks
+ )
+ except MandatoryChecksMissingError as ex:
+ if ex.rule is not None:
+ ignore_flaky_failures = ex.rule.ignore_flaky_failures
+ if ex.rule.mandatory_checks_name is not None:
+ required_checks = ex.rule.mandatory_checks_name
+ failed_rule_message = ex
+
+ checks = pr.get_checkrun_conclusions()
+ checks = get_classifications(
+ pr.pr_num,
+ pr.project,
+ checks,
+ ignore_current_checks=ignore_current_checks,
+ )
+ pending, failing, _ = categorize_checks(
+ checks,
+ required_checks
+ + [x for x in checks.keys() if x not in required_checks],
+ ok_failed_checks_threshold=(
+ IGNORABLE_FAILED_CHECKS_THESHOLD if ignore_flaky_failures else 0
+ ),
+ )
+ # HACK until GitHub will be better about surfacing those
+ startup_failures = filter_checks_with_lambda(
+ checks, lambda status: status == "STARTUP_FAILURE"
+ )
+ if len(startup_failures) > 0:
+ raise RuntimeError(
+ f"{len(startup_failures)} STARTUP failures reported, please check workflows syntax! "
+ + ", ".join(f"[{x.name}]({x.url})" for x in startup_failures[:5])
+ )
+ # END of HACK
+
+ if len(failing) > 0:
+ raise RuntimeError(
+ f"{len(failing)} jobs have failed, first few of them are: "
+ + ", ".join(f"[{x[0]}]({x[1]})" for x in failing[:5])
+ )
+ if len(pending) > 0:
+ if failed_rule_message is not None:
+ raise failed_rule_message
+ else:
+ raise MandatoryChecksMissingError(
+ f"Still waiting for {len(pending)} jobs to finish, "
+ + f"first few of them are: {', '.join(x[0] for x in pending[:5])}"
+ )
+
+ return pr.merge_into(
+ repo,
+ dry_run=dry_run,
+ skip_mandatory_checks=skip_mandatory_checks,
+ comment_id=comment_id,
+ ignore_current_checks=ignore_current_checks,
+ )
+ except MandatoryChecksMissingError as ex:
+ last_exception = str(ex)
+ print(
+ f"Merge of https://github.com/{pr.org}/{pr.project}/pull/{pr.pr_num} failed due to: {ex}. Retrying in 5 min"
+ )
+ time.sleep(5 * 60)
+ # Finally report timeout back
+ msg = f"Merged timed out after {timeout_minutes} minutes. Please contact the pytorch_dev_infra team."
+ msg += f"The last exception was: {last_exception}"
+ gh_add_labels(pr.org, pr.project, pr.pr_num, ["land-failed"], dry_run)
+ raise RuntimeError(msg)
+
+
+def main() -> None:
+ args = parse_args()
+ repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
+ org, project = repo.gh_owner_and_name()
+ pr = GitHubPR(org, project, args.pr_num)
+
+ def handle_exception(e: Exception, title: str = "Merge failed") -> None:
+ exception = f"**Reason**: {e}"
+
+ failing_rule = None
+ if isinstance(e, MergeRuleFailedError):
+ failing_rule = e.rule.name if e.rule else None
+
+ internal_debugging = ""
+ run_url = os.getenv("GH_RUN_URL")
+ if run_url is not None:
+ # Hide this behind a collapsed bullet since it's not helpful to most devs
+ internal_debugging = "\n".join(
+ line
+ for line in (
+ "Details for Dev Infra team
",
+ f'Raised by workflow job\n',
+ f"Failing merge rule: {failing_rule}" if failing_rule else "",
+ " ",
+ )
+ if line
+ ) # ignore empty lines during the join
+
+ msg = "\n".join((f"## {title}", f"{exception}", "", f"{internal_debugging}"))
+
+ gh_post_pr_comment(org, project, args.pr_num, msg, dry_run=args.dry_run)
+ import traceback
+
+ traceback.print_exc()
+
+ if args.revert:
+ try:
+ gh_post_pr_comment(
+ org,
+ project,
+ args.pr_num,
+ get_revert_message(org, project, pr.pr_num),
+ args.dry_run,
+ )
+ try_revert(
+ repo,
+ pr,
+ dry_run=args.dry_run,
+ comment_id=args.comment_id,
+ reason=args.reason,
+ )
+ except Exception as e:
+ handle_exception(e, f"Reverting PR {args.pr_num} failed")
+ return
+
+ if pr.is_closed():
+ gh_post_pr_comment(
+ org,
+ project,
+ args.pr_num,
+ f"Can't merge closed PR #{args.pr_num}",
+ dry_run=args.dry_run,
+ )
+ return
+
+ if pr.is_cross_repo() and pr.is_ghstack_pr():
+ gh_post_pr_comment(
+ org,
+ project,
+ args.pr_num,
+ "Cross-repo ghstack merges are not supported",
+ dry_run=args.dry_run,
+ )
+ return
+
+ if args.check_mergeability:
+ if pr.is_ghstack_pr():
+ get_ghstack_prs(repo, pr) # raises error if out of sync
+ pr.merge_changes(
+ repo,
+ skip_mandatory_checks=True,
+ skip_all_rule_checks=True,
+ )
+ return
+
+ if not args.force and pr.has_invalid_submodule_updates():
+ message = (
+ f"This PR updates submodules {', '.join(pr.get_changed_submodules())}\n"
+ )
+ message += '\nIf those updates are intentional, please add "submodule" keyword to PR title/description.'
+ gh_post_pr_comment(org, project, args.pr_num, message, dry_run=args.dry_run)
+ return
+ try:
+ merge(
+ pr,
+ repo,
+ dry_run=args.dry_run,
+ skip_mandatory_checks=args.force,
+ comment_id=args.comment_id,
+ ignore_current=args.ignore_current,
+ )
+ except Exception as e:
+ handle_exception(e)
+
+ if args.comment_id and args.pr_num:
+ # Finally, upload the record to Rockset, we don't have access to the
+ # list of pending and failed checks here, but they are not really
+ # needed at the moment
+ save_merge_record(
+ collection=ROCKSET_MERGES_COLLECTION,
+ comment_id=args.comment_id,
+ pr_num=args.pr_num,
+ owner=org,
+ project=project,
+ author=pr.get_author(),
+ pending_checks=[],
+ failed_checks=[],
+ ignore_current_checks=[],
+ broken_trunk_checks=[],
+ flaky_checks=[],
+ unstable_checks=[],
+ last_commit_sha=pr.last_commit().get("oid", ""),
+ merge_base_sha=pr.get_merge_base(),
+ is_failed=True,
+ dry_run=args.dry_run,
+ skip_mandatory_checks=args.force,
+ ignore_current=args.ignore_current,
+ error=str(e),
+ workspace=ROCKSET_MERGES_WORKSPACE,
+ )
+ else:
+ print("Missing comment ID or PR number, couldn't upload to Rockset")
+ finally:
+ if not args.check_mergeability:
+ gh_remove_label(
+ org, project, args.pr_num, MERGE_IN_PROGRESS_LABEL, args.dry_run
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/.github/scripts/trymerge_explainer.py b/.github/scripts/trymerge_explainer.py
new file mode 100644
index 0000000000..4b472a4cb5
--- /dev/null
+++ b/.github/scripts/trymerge_explainer.py
@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import re
+from typing import List, Optional, Pattern, Tuple
+
+
+BOT_COMMANDS_WIKI = "https://github.com/pytorch/pytorch/wiki/Bot-commands"
+
+CIFLOW_LABEL = re.compile(r"^ciflow/.+")
+CIFLOW_TRUNK_LABEL = re.compile(r"^ciflow/trunk")
+
+OFFICE_HOURS_LINK = "https://github.com/pytorch/pytorch/wiki/Dev-Infra-Office-Hours"
+CONTACT_US = f"Questions? Feedback? Please reach out to the [PyTorch DevX Team]({OFFICE_HOURS_LINK})"
+ALTERNATIVES = f"Learn more about merging in the [wiki]({BOT_COMMANDS_WIKI})."
+
+
+def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
+ return len(list(filter(pattern.match, labels))) > 0
+
+
+class TryMergeExplainer:
+ force: bool
+ labels: List[str]
+ pr_num: int
+ org: str
+ project: str
+ ignore_current: bool
+
+ has_trunk_label: bool
+ has_ciflow_label: bool
+
+ def __init__(
+ self,
+ force: bool,
+ labels: List[str],
+ pr_num: int,
+ org: str,
+ project: str,
+ ignore_current: bool,
+ ):
+ self.force = force
+ self.labels = labels
+ self.pr_num = pr_num
+ self.org = org
+ self.project = project
+ self.ignore_current = ignore_current
+
+ def _get_flag_msg(
+ self,
+ ignore_current_checks: Optional[
+ List[Tuple[str, Optional[str], Optional[int]]]
+ ] = None,
+ ) -> str:
+ if self.force:
+ return (
+ "Your change will be merged immediately since you used the force (-f) flag, "
+ + "**bypassing any CI checks** (ETA: 1-5 minutes). "
+ + "Please use `-f` as last resort and instead consider `-i/--ignore-current` "
+ + "to continue the merge ignoring current failures. This will allow "
+ + "currently pending tests to finish and report signal before the merge."
+ )
+ elif self.ignore_current and ignore_current_checks is not None:
+ msg = f"Your change will be merged while ignoring the following {len(ignore_current_checks)} checks: "
+ msg += ", ".join(f"[{x[0]}]({x[1]})" for x in ignore_current_checks)
+ return msg
+ else:
+ return "Your change will be merged once all checks pass (ETA 0-4 Hours)."
+
+ def get_merge_message(
+ self,
+ ignore_current_checks: Optional[
+ List[Tuple[str, Optional[str], Optional[int]]]
+ ] = None,
+ ) -> str:
+ title = "### Merge started"
+ main_message = self._get_flag_msg(ignore_current_checks)
+
+ advanced_debugging = "\n".join(
+ (
+ "Advanced Debugging
",
+ "Check the merge workflow status ",
+ f"here",
+ " ",
+ )
+ )
+
+ msg = title + "\n"
+ msg += main_message + "\n\n"
+ msg += ALTERNATIVES + "\n\n"
+ msg += CONTACT_US
+ msg += advanced_debugging
+ return msg
+
+
+def get_revert_message(org: str, project: str, pr_num: int) -> str:
+ msg = (
+ "@pytorchbot successfully started a revert job."
+ + f" Check the current status [here]({os.getenv('GH_RUN_URL')}).\n"
+ )
+ msg += CONTACT_US
+ return msg
diff --git a/.github/workflows/apple.yml b/.github/workflows/apple.yml
index bd4767e558..04a369ec72 100644
--- a/.github/workflows/apple.yml
+++ b/.github/workflows/apple.yml
@@ -89,15 +89,24 @@ jobs:
PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \
backends/apple/mps/install_requirements.sh
- # Build iOS Frameworks
+ # Build Release iOS Frameworks
PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \
build/build_apple_frameworks.sh --coreml --mps --optimized --portable --quantized --xnnpack
- # Bundle iOS Frameworks
+ # Bundle Release iOS Frameworks
for FRAMEWORK in "${FRAMEWORKS[@]}"; do (
cd cmake-out && zip -r "${RUNNER_TEMP}/artifacts/${FRAMEWORK}-${VERSION}.zip" "${FRAMEWORK}.xcframework"
) done
+ # Build Debug iOS Frameworks
+ PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \
+ build/build_apple_frameworks.sh --coreml --mps --optimized --portable --quantized --xnnpack --Debug
+
+ # Bundle Debug iOS Frameworks
+ for FRAMEWORK in "${FRAMEWORKS[@]}"; do (
+ cd cmake-out && zip -r "${RUNNER_TEMP}/artifacts/${FRAMEWORK}_debug-${VERSION}.zip" "${FRAMEWORK}_debug.xcframework"
+ ) done
+
popd
upload-frameworks-ios:
diff --git a/.github/workflows/cherry-pick.yml b/.github/workflows/cherry-pick.yml
index 8382df3ec4..b33c0a0ca4 100644
--- a/.github/workflows/cherry-pick.yml
+++ b/.github/workflows/cherry-pick.yml
@@ -46,7 +46,7 @@ jobs:
run: |
set -ex
- python ./third-party/pytorch/.github/scripts/cherry_pick.py \
+ python .github/scripts/cherry_pick.py \
--onto-branch "${BRANCH}" \
--classification "${CLASSIFICATION}" \
--fixes "${FIXES}" \
diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index 32a856ea0a..304f24529f 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -90,6 +90,7 @@ jobs:
matrix:
dtype: [fp32]
build-tool: [buck2, cmake]
+ mode: [portable, xnnpack]
fail-fast: false
with:
runner: linux.2xlarge
@@ -104,13 +105,14 @@ jobs:
DTYPE=${{ matrix.dtype }}
BUILD_TOOL=${{ matrix.build-tool }}
+ MODE=${{ matrix.mode }}
# Setup executorch
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh buck2
# Install requirements for export_llama
PYTHON_EXECUTABLE=python bash examples/models/llama2/install_requirements.sh
# Test llama2
- PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh stories110M.pt "${BUILD_TOOL}" "${DTYPE}"
+ PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh stories110M.pt "${BUILD_TOOL}" "${DTYPE}" "${MODE}"
test-custom-ops-linux:
name: test-custom-ops-linux
diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml
index 826b8ab45d..2558333554 100644
--- a/.github/workflows/trunk.yml
+++ b/.github/workflows/trunk.yml
@@ -215,3 +215,63 @@ jobs:
# Build and test coreml delegate
PYTHON_EXECUTABLE=python ${CONDA_RUN} bash backends/apple/coreml/scripts/build_all.sh
popd
+
+ test-pybind-build-macos:
+ name: test-pybind-build-macos
+ uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
+ strategy:
+ matrix:
+ include:
+ - build-tool: cmake
+ fail-fast: false
+ with:
+ runner: macos-m1-stable
+ python-version: '3.11'
+ submodules: 'true'
+ ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+ timeout: 180
+ script: |
+ WORKSPACE=$(pwd)
+ pushd "${WORKSPACE}/pytorch/executorch"
+ bash .ci/scripts/setup-conda.sh
+
+ # build module for executorch.extension.pybindings.portable_lib
+ BUILD_TOOL=${{ matrix.build-tool }}
+ EXECUTORCH_BUILD_PYBIND=ON PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/setup-macos.sh "${BUILD_TOOL}"
+
+ # see if we can import the module successfully
+ ${CONDA_RUN} python -c "from executorch.extension.pybindings import portable_lib; print('success!')"
+ popd
+
+ test-llama-runner-macos:
+ name: test-llama-runner-mac
+ uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
+ strategy:
+ matrix:
+ dtype: [fp32]
+ build-tool: [buck2, cmake]
+ mode: [portable, xnnpack]
+ fail-fast: false
+ with:
+ runner: macos-m1-stable
+ python-version: '3.11'
+ submodules: 'true'
+ ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+ timeout: 900
+ script: |
+ WORKSPACE=$(pwd)
+ pushd "${WORKSPACE}/pytorch/executorch"
+ bash .ci/scripts/setup-conda.sh
+
+ DTYPE=${{ matrix.dtype }}
+ BUILD_TOOL=${{ matrix.build-tool }}
+ MODE=${{ matrix.mode }}
+
+ # Setup executorch
+ PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/setup-macos.sh "${BUILD_TOOL}"
+
+ # Install requirements for export_llama
+ PYTHON_EXECUTABLE=python ${CONDA_RUN} bash examples/models/llama2/install_requirements.sh
+ # Test llama2
+ PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/test_llama.sh stories110M.pt "${BUILD_TOOL}" "${DTYPE}" "${MODE}"
+ popd
diff --git a/.lintrunner.toml b/.lintrunner.toml
index ab31cd9d34..a00b04e83e 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -6,6 +6,7 @@ include_patterns = ['**/*.py']
exclude_patterns = [
'third-party/**',
'**/third-party/**',
+ '.github/scripts/**',
]
command = [
'python',
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3680c5a44b..432a809d35 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -13,15 +13,6 @@
# cloning or pulling the upstream repo. Once this is done, you don't need to do
# it again until you pull from the upstream repo again.
#
-# NOTE: If your `buck2` binary is not on the PATH, you can change this line to
-# say something like `-DBUCK2=/tmp/buck2` to point directly to the tool.
-#[[
- (rm -rf cmake-out \
- && mkdir cmake-out \
- && cd cmake-out \
- && cmake -DBUCK2=buck2 ..)
-]]
-#
# ### Build ###
#
# NOTE: The `-j` argument specifies how many jobs/processes to use when
@@ -140,8 +131,6 @@ option(EXECUTORCH_BUILD_ARM_BAREMETAL
option(EXECUTORCH_BUILD_COREML "Build the Core ML backend" OFF)
-option(EXECUTORCH_BUILD_EXTENSION_AOT_UTIL "Build the AOT util library" OFF)
-
option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "Build the Data Loader extension"
OFF)
@@ -171,8 +160,9 @@ option(EXECUTORCH_BUILD_XNNPACK "Build the XNNPACK backend" OFF)
option(EXECUTORCH_BUILD_VULKAN "Build the Vulkan backend" OFF)
if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
+ resolve_python_executable()
endif()
+message(STATUS "Using python executable '${PYTHON_EXECUTABLE}'")
# TODO(dbort): Fix these warnings and remove this flag.
set(_common_compile_options -Wno-deprecated-declarations -fPIC)
@@ -372,10 +362,6 @@ if(EXECUTORCH_BUILD_EXECUTOR_RUNNER)
target_compile_options(executor_runner PUBLIC ${_common_compile_options})
endif()
-if(EXECUTORCH_BUILD_EXTENSION_AOT_UTIL)
- add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/aot_util)
- endif()
-
# Add googletest if any test targets should be built
if(EXECUTORCH_BUILD_GTESTS)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/googletest)
@@ -459,7 +445,7 @@ if(EXECUTORCH_BUILD_PYBIND)
# find pytorch lib, to allow pybind to take at::Tensor as input/output
find_package(Torch CONFIG REQUIRED)
find_library(TORCH_PYTHON_LIBRARY torch_python
- PATHS "${TORCH_INSTALL_PREFIX}/lib")
+ PATHS "${TORCH_INSTALL_PREFIX}/lib")
# compile options for pybind
diff --git a/Package.swift b/Package.swift
index 3755c439b7..8dda54be47 100644
--- a/Package.swift
+++ b/Package.swift
@@ -11,13 +11,13 @@ import PackageDescription
let url = "https://ossci-ios.s3.amazonaws.com/executorch"
let version = "0.1.0"
-let coreml_sha256 = "a08d3a06f65c6c124214b27de77057832452206625cde36261b4b6a346314802"
-let executorch_sha256 = "ee0c1b870036834f7ac0dbf99fa396990243a96e0939c7d4f0ea341b794dcc38"
-let mps_sha256 = "020fedd9f7670422c132da42ddf3b9307c67f12f85c6928109f1d4885c67b1ca"
-let optimized_sha256 = "e5f3d9814758d79da7547c1936e7a665e305a82e4d6f340e25e41b6b924e45d1"
-let portable_sha256 = "968a8aa09794b69d60c9cfb6c9cfc37c8842a51fd0cafa14f7b7daa4d8e80eea"
-let quantized_sha256 = "e46e4252f5d0f134bf2edbf559ad07c92c49288dfcab21fa7406e1424051de1f"
-let xnnpack_sha256 = "016d4b3f947c267d9ffd4884198730a0f5a5a606d3376addd96e45aaa7a366cc"
+let coreml_sha256 = "e8c5000a389bdc98274aa0b359350a47e6d0cccb8af5efc46f814feac6afaf86"
+let executorch_sha256 = "e6c5d798b614a03ab8a4891caeaa8a7adf8d58ba29e767079321691ec9f1ffb4"
+let mps_sha256 = "3e54e3166b5e739cb3f76b2bc6f7b1982a0401821ab785a93120bacfde4bc1ee"
+let optimized_sha256 = "4d353f44badd321cf29fe548db9d66b493b93c6233a7e023988e256f0eefeaa1"
+let portable_sha256 = "c501f9b644a3e8a7bab62600b7802e4a9752fb789ba4fd02f46bec47858cec07"
+let quantized_sha256 = "4fb5f7216abc0ee16ece91a4bce822b06d67b52ca985c9eecbf9d3f8bd1ea1ba"
+let xnnpack_sha256 = "e610904cfd6e96f8f738c25a7bb4f6d7b86995b2cfeb72fc1f30523630dbb285"
struct Framework {
let name: String
diff --git a/backends/apple/mps/CMakeLists.txt b/backends/apple/mps/CMakeLists.txt
index 0a82f38498..ef64e26f2c 100644
--- a/backends/apple/mps/CMakeLists.txt
+++ b/backends/apple/mps/CMakeLists.txt
@@ -11,15 +11,17 @@ if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
endif()
-if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
-endif()
-
# Source root directory for executorch.
if(NOT EXECUTORCH_ROOT)
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
endif()
+include(${EXECUTORCH_ROOT}/build/Utils.cmake)
+
+if(NOT PYTHON_EXECUTABLE)
+ resolve_python_executable()
+endif()
+
if(NOT FLATC_EXECUTABLE)
set(FLATC_EXECUTABLE flatc)
endif()
diff --git a/backends/arm/third-party/serialization_lib b/backends/arm/third-party/serialization_lib
index bd8c5295be..187af0d41f 160000
--- a/backends/arm/third-party/serialization_lib
+++ b/backends/arm/third-party/serialization_lib
@@ -1 +1 @@
-Subproject commit bd8c5295bef8fced1ad11323cc7204c620ccd1fb
+Subproject commit 187af0d41fe75d08d2a7ec84c1b4d24b9b641ed2
diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py
index e5ce8ec2d7..a362c7dd8f 100644
--- a/backends/qualcomm/builders/node_visitor.py
+++ b/backends/qualcomm/builders/node_visitor.py
@@ -396,7 +396,8 @@ def register_node_visitor(visitor):
and issubclass(visitor, NodeVisitor)
and hasattr(visitor, "target")
), f"Illformed NodeVisitor subclass, can't register!, got: {visitor}"
- _node_visitor_dict[visitor.target] = visitor
+ for target in visitor.target:
+ _node_visitor_dict[target] = visitor
def generate_node_to_external_map(
diff --git a/backends/qualcomm/builders/op_add.py b/backends/qualcomm/builders/op_add.py
index f151ca6698..ce61db2d6a 100644
--- a/backends/qualcomm/builders/op_add.py
+++ b/backends/qualcomm/builders/op_add.py
@@ -15,7 +15,7 @@
@register_node_visitor
class Add(NodeVisitor):
- target = "aten.add.Tensor"
+ target = ["aten.add.Tensor"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_avg_pool2d.py b/backends/qualcomm/builders/op_avg_pool2d.py
index 38c3bd6d47..e6a6fd3f1c 100644
--- a/backends/qualcomm/builders/op_avg_pool2d.py
+++ b/backends/qualcomm/builders/op_avg_pool2d.py
@@ -16,7 +16,7 @@
@register_node_visitor
class AvgPool2d(NodeVisitor):
- target = "aten.avg_pool2d.default"
+ target = ["aten.avg_pool2d.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_batch_norm.py b/backends/qualcomm/builders/op_batch_norm.py
index 280cc86d7b..a0efda4072 100644
--- a/backends/qualcomm/builders/op_batch_norm.py
+++ b/backends/qualcomm/builders/op_batch_norm.py
@@ -16,7 +16,7 @@
@register_node_visitor
class BatchNorm(NodeVisitor):
- target = "aten._native_batch_norm_legit_no_training.default"
+ target = ["aten._native_batch_norm_legit_no_training.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_bmm.py b/backends/qualcomm/builders/op_bmm.py
index 4648321a6b..c207d73ad7 100644
--- a/backends/qualcomm/builders/op_bmm.py
+++ b/backends/qualcomm/builders/op_bmm.py
@@ -15,7 +15,7 @@
@register_node_visitor
class BMM(NodeVisitor):
- target = "aten.bmm.default"
+ target = ["aten.bmm.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_cast.py b/backends/qualcomm/builders/op_cast.py
index 18666e5544..d8173126c1 100644
--- a/backends/qualcomm/builders/op_cast.py
+++ b/backends/qualcomm/builders/op_cast.py
@@ -15,7 +15,7 @@
@register_node_visitor
class Cast(NodeVisitor):
- target = "aten._to_copy.default"
+ target = ["aten._to_copy.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_cat.py b/backends/qualcomm/builders/op_cat.py
index 4cbfd6b542..bd7e8153c7 100644
--- a/backends/qualcomm/builders/op_cat.py
+++ b/backends/qualcomm/builders/op_cat.py
@@ -16,7 +16,7 @@
@register_node_visitor
class Cat(NodeVisitor):
- target = "aten.cat.default"
+ target = ["aten.cat.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_ceil.py b/backends/qualcomm/builders/op_ceil.py
index 00ce561440..c486669fc8 100644
--- a/backends/qualcomm/builders/op_ceil.py
+++ b/backends/qualcomm/builders/op_ceil.py
@@ -15,7 +15,7 @@
@register_node_visitor
class Ceil(NodeVisitor):
- target = "aten.ceil.default"
+ target = ["aten.ceil.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_clamp.py b/backends/qualcomm/builders/op_clamp.py
index 9417f726d5..24f2e01964 100644
--- a/backends/qualcomm/builders/op_clamp.py
+++ b/backends/qualcomm/builders/op_clamp.py
@@ -16,7 +16,7 @@
@register_node_visitor
class Clamp(NodeVisitor):
- target = "aten.clamp.default"
+ target = ["aten.clamp.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_conv2d.py b/backends/qualcomm/builders/op_conv2d.py
index f899e98efd..5c20b1372e 100644
--- a/backends/qualcomm/builders/op_conv2d.py
+++ b/backends/qualcomm/builders/op_conv2d.py
@@ -24,7 +24,7 @@
@register_node_visitor
class Conv2d(NodeVisitor):
- target = "aten.convolution.default"
+ target = ["aten.convolution.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_depth_to_space.py b/backends/qualcomm/builders/op_depth_to_space.py
index 41e141cbfa..8624b6eb07 100644
--- a/backends/qualcomm/builders/op_depth_to_space.py
+++ b/backends/qualcomm/builders/op_depth_to_space.py
@@ -17,7 +17,7 @@
@register_node_visitor
class DepthToSpaceVisitor(NodeVisitor):
- target = "aten.pixel_shuffle.default"
+ target = ["aten.pixel_shuffle.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_dequantize.py b/backends/qualcomm/builders/op_dequantize.py
index 0574a4e2e2..56eb609575 100644
--- a/backends/qualcomm/builders/op_dequantize.py
+++ b/backends/qualcomm/builders/op_dequantize.py
@@ -55,19 +55,19 @@ def define_node(
@register_node_visitor
class PerTensorDequantizeDefault(DequantizeOpBase):
- target = "quantized_decomposed.dequantize_per_tensor.default"
+ target = ["quantized_decomposed.dequantize_per_tensor.default"]
@register_node_visitor
class PerTensorDequantizeTensor(DequantizeOpBase):
- target = "quantized_decomposed.dequantize_per_tensor.tensor"
+ target = ["quantized_decomposed.dequantize_per_tensor.tensor"]
@register_node_visitor
class PerChannelDequantizeDefault(DequantizeOpBase):
- target = "quantized_decomposed.dequantize_per_channel.default"
+ target = ["quantized_decomposed.dequantize_per_channel.default"]
@register_node_visitor
class PerChannelDequantizeTensor(DequantizeOpBase):
- target = "quantized_decomposed.dequantize_per_channel.tensor"
+ target = ["quantized_decomposed.dequantize_per_channel.tensor"]
diff --git a/backends/qualcomm/builders/op_div.py b/backends/qualcomm/builders/op_div.py
index 4f0157bbdf..6b4e674349 100644
--- a/backends/qualcomm/builders/op_div.py
+++ b/backends/qualcomm/builders/op_div.py
@@ -15,7 +15,7 @@
@register_node_visitor
class Div(NodeVisitor):
- target = "aten.div.Tensor"
+ target = ["aten.div.Tensor"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_embedding.py b/backends/qualcomm/builders/op_embedding.py
index 60d3a3906c..faf33eac12 100644
--- a/backends/qualcomm/builders/op_embedding.py
+++ b/backends/qualcomm/builders/op_embedding.py
@@ -17,7 +17,7 @@
@register_node_visitor
class Embedding(NodeVisitor):
- target = "aten.embedding.default"
+ target = ["aten.embedding.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_expand.py b/backends/qualcomm/builders/op_expand.py
index afef5e2269..a1ef1c2949 100644
--- a/backends/qualcomm/builders/op_expand.py
+++ b/backends/qualcomm/builders/op_expand.py
@@ -16,7 +16,7 @@
@register_node_visitor
class Expand(NodeVisitor):
- target = "aten.expand_copy.default"
+ target = ["aten.expand_copy.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_gelu.py b/backends/qualcomm/builders/op_gelu.py
index 7dd627ce58..c488d6b5d8 100644
--- a/backends/qualcomm/builders/op_gelu.py
+++ b/backends/qualcomm/builders/op_gelu.py
@@ -16,7 +16,7 @@
@register_node_visitor
class GeluVisitor(NodeVisitor):
- target = "aten.gelu.default"
+ target = ["aten.gelu.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_hardswish.py b/backends/qualcomm/builders/op_hardswish.py
index 940bfc7d42..c7ad702ae6 100644
--- a/backends/qualcomm/builders/op_hardswish.py
+++ b/backends/qualcomm/builders/op_hardswish.py
@@ -16,7 +16,7 @@
@register_node_visitor
class HardSwishVisitor(NodeVisitor):
- target = "aten.hardswish.default"
+ target = ["aten.hardswish.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_hardtanh.py b/backends/qualcomm/builders/op_hardtanh.py
index 0f16d006da..d7d322cbdc 100644
--- a/backends/qualcomm/builders/op_hardtanh.py
+++ b/backends/qualcomm/builders/op_hardtanh.py
@@ -17,7 +17,7 @@
@register_node_visitor
class HardTanhVisitor(NodeVisitor):
- target = "aten.hardtanh.default"
+ target = ["aten.hardtanh.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_layer_norm.py b/backends/qualcomm/builders/op_layer_norm.py
index 53a30e434f..1f4b47672d 100644
--- a/backends/qualcomm/builders/op_layer_norm.py
+++ b/backends/qualcomm/builders/op_layer_norm.py
@@ -18,7 +18,7 @@
@register_node_visitor
class LayerNormVisitor(NodeVisitor):
- target = "aten.native_layer_norm.default"
+ target = ["aten.native_layer_norm.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_linear.py b/backends/qualcomm/builders/op_linear.py
index 907bda3d81..1e75df7b1a 100644
--- a/backends/qualcomm/builders/op_linear.py
+++ b/backends/qualcomm/builders/op_linear.py
@@ -17,7 +17,7 @@
@register_node_visitor
class LinearVisitor(NodeVisitor):
- target = "aten.linear.default"
+ target = ["aten.linear.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_log_softmax.py b/backends/qualcomm/builders/op_log_softmax.py
index a8259a5ca9..de01a14f89 100644
--- a/backends/qualcomm/builders/op_log_softmax.py
+++ b/backends/qualcomm/builders/op_log_softmax.py
@@ -16,7 +16,7 @@
@register_node_visitor
class LogSoftmax(NodeVisitor):
- target = "aten._log_softmax.default"
+ target = ["aten._log_softmax.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_matmul.py b/backends/qualcomm/builders/op_matmul.py
index 68540949b7..9a94cb1d60 100644
--- a/backends/qualcomm/builders/op_matmul.py
+++ b/backends/qualcomm/builders/op_matmul.py
@@ -15,7 +15,7 @@
@register_node_visitor
class Matmul(NodeVisitor):
- target = "aten.matmul.default"
+ target = ["aten.matmul.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_max_pool2d.py b/backends/qualcomm/builders/op_max_pool2d.py
index 9b8076ba22..f64a13faed 100644
--- a/backends/qualcomm/builders/op_max_pool2d.py
+++ b/backends/qualcomm/builders/op_max_pool2d.py
@@ -16,7 +16,7 @@
@register_node_visitor
class MaxPool2d(NodeVisitor):
- target = "aten.max_pool2d_with_indices.default"
+ target = ["aten.max_pool2d_with_indices.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_mean_dim.py b/backends/qualcomm/builders/op_mean_dim.py
index 4d151eb9f2..29d9f8b30f 100644
--- a/backends/qualcomm/builders/op_mean_dim.py
+++ b/backends/qualcomm/builders/op_mean_dim.py
@@ -17,7 +17,7 @@
@register_node_visitor
class MeanDim(NodeVisitor):
- target = "aten.mean.dim"
+ target = ["aten.mean.dim"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_mul.py b/backends/qualcomm/builders/op_mul.py
index 943891a993..645910b7d9 100644
--- a/backends/qualcomm/builders/op_mul.py
+++ b/backends/qualcomm/builders/op_mul.py
@@ -15,7 +15,7 @@
@register_node_visitor
class Mul(NodeVisitor):
- target = "aten.mul.Tensor"
+ target = ["aten.mul.Tensor"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_pad.py b/backends/qualcomm/builders/op_pad.py
index bf3bbbcab8..677cf77bd2 100644
--- a/backends/qualcomm/builders/op_pad.py
+++ b/backends/qualcomm/builders/op_pad.py
@@ -16,7 +16,7 @@
@register_node_visitor
class Pad(NodeVisitor):
- target = "aten.constant_pad_nd.default"
+ target = ["aten.constant_pad_nd.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_pow.py b/backends/qualcomm/builders/op_pow.py
index 14f4edd9f5..cae2e68161 100644
--- a/backends/qualcomm/builders/op_pow.py
+++ b/backends/qualcomm/builders/op_pow.py
@@ -17,7 +17,7 @@
# TODO Add more class Like PowTensorTensor if needed
@register_node_visitor
class PowTensorScalar(NodeVisitor):
- target = "aten.pow.Tensor_Scalar"
+ target = ["aten.pow.Tensor_Scalar"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_quantize.py b/backends/qualcomm/builders/op_quantize.py
index e1d491cadb..b74ca7fb6d 100644
--- a/backends/qualcomm/builders/op_quantize.py
+++ b/backends/qualcomm/builders/op_quantize.py
@@ -61,9 +61,9 @@ def define_node(
@register_node_visitor
class PerTensorQuantize(QuantizeOpBase):
- target = "quantized_decomposed.quantize_per_tensor.default"
+ target = ["quantized_decomposed.quantize_per_tensor.default"]
@register_node_visitor
class PerChannelQuantize(QuantizeOpBase):
- target = "quantized_decomposed.quantize_per_channel.default"
+ target = ["quantized_decomposed.quantize_per_channel.default"]
diff --git a/backends/qualcomm/builders/op_relu.py b/backends/qualcomm/builders/op_relu.py
index 52cbd410ee..d6c5ff79bc 100644
--- a/backends/qualcomm/builders/op_relu.py
+++ b/backends/qualcomm/builders/op_relu.py
@@ -15,7 +15,7 @@
@register_node_visitor
class Relu(NodeVisitor):
- target = "aten.relu.default"
+ target = ["aten.relu.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_reshape.py b/backends/qualcomm/builders/op_reshape.py
index 96278b0f80..23eb1ff59b 100644
--- a/backends/qualcomm/builders/op_reshape.py
+++ b/backends/qualcomm/builders/op_reshape.py
@@ -15,7 +15,7 @@
@register_node_visitor
class Reshape(NodeVisitor):
- target = "aten.view_copy.default"
+ target = ["aten.view_copy.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_rsqrt.py b/backends/qualcomm/builders/op_rsqrt.py
index 5976cab67f..cf3e8c5e38 100644
--- a/backends/qualcomm/builders/op_rsqrt.py
+++ b/backends/qualcomm/builders/op_rsqrt.py
@@ -15,7 +15,7 @@
@register_node_visitor
class Rsqrt(NodeVisitor):
- target = "aten.rsqrt.default"
+ target = ["aten.rsqrt.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_select_copy.py b/backends/qualcomm/builders/op_select_copy.py
index ea53467521..5d74d038f7 100644
--- a/backends/qualcomm/builders/op_select_copy.py
+++ b/backends/qualcomm/builders/op_select_copy.py
@@ -17,7 +17,7 @@
@register_node_visitor
class SelectCopy(NodeVisitor):
- target = "aten.select_copy.int"
+ target = ["aten.select_copy.int", "aten.select.int"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_sigmoid.py b/backends/qualcomm/builders/op_sigmoid.py
index b6eeb88935..3b7dd2abe2 100644
--- a/backends/qualcomm/builders/op_sigmoid.py
+++ b/backends/qualcomm/builders/op_sigmoid.py
@@ -15,7 +15,7 @@
@register_node_visitor
class Sigmoid(NodeVisitor):
- target = "aten.sigmoid.default"
+ target = ["aten.sigmoid.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_skip_ops.py b/backends/qualcomm/builders/op_skip_ops.py
index f91ef70d44..9a1839f604 100644
--- a/backends/qualcomm/builders/op_skip_ops.py
+++ b/backends/qualcomm/builders/op_skip_ops.py
@@ -35,7 +35,7 @@ class OpGetItem(OpSkipOps):
do nothing if node is getitem
"""
- target = "getitem"
+ target = ["getitem"]
def define_node(
self,
diff --git a/backends/qualcomm/builders/op_slice_copy.py b/backends/qualcomm/builders/op_slice_copy.py
index 5ed2b99cc0..6d121135e4 100644
--- a/backends/qualcomm/builders/op_slice_copy.py
+++ b/backends/qualcomm/builders/op_slice_copy.py
@@ -16,7 +16,7 @@
@register_node_visitor
class StrideSlice(NodeVisitor):
- target = "aten.slice_copy.Tensor"
+ target = ["aten.slice_copy.Tensor"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_softmax.py b/backends/qualcomm/builders/op_softmax.py
index 2a1abce3d5..031c0244f3 100644
--- a/backends/qualcomm/builders/op_softmax.py
+++ b/backends/qualcomm/builders/op_softmax.py
@@ -16,7 +16,7 @@
@register_node_visitor
class Softmax(NodeVisitor):
- target = "aten._softmax.default"
+ target = ["aten._softmax.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_squeeze.py b/backends/qualcomm/builders/op_squeeze.py
index 43ef39fab7..b13643783c 100644
--- a/backends/qualcomm/builders/op_squeeze.py
+++ b/backends/qualcomm/builders/op_squeeze.py
@@ -15,7 +15,7 @@
@register_node_visitor
class Squeeze(NodeVisitor):
- target = "aten.squeeze_copy.dims"
+ target = ["aten.squeeze_copy.dims", "aten.squeeze.dims"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_sub.py b/backends/qualcomm/builders/op_sub.py
index 212e7a75cd..131fecd4cf 100644
--- a/backends/qualcomm/builders/op_sub.py
+++ b/backends/qualcomm/builders/op_sub.py
@@ -15,7 +15,7 @@
@register_node_visitor
class Sub(NodeVisitor):
- target = "aten.sub.Tensor"
+ target = ["aten.sub.Tensor"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_tanh.py b/backends/qualcomm/builders/op_tanh.py
index cff4f7e447..af37256046 100644
--- a/backends/qualcomm/builders/op_tanh.py
+++ b/backends/qualcomm/builders/op_tanh.py
@@ -16,7 +16,7 @@
@register_node_visitor
class Tanh(NodeVisitor):
- target = "aten.tanh.default"
+ target = ["aten.tanh.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_transpose.py b/backends/qualcomm/builders/op_transpose.py
index 161e8cef9d..7dc9352673 100644
--- a/backends/qualcomm/builders/op_transpose.py
+++ b/backends/qualcomm/builders/op_transpose.py
@@ -17,7 +17,7 @@
@register_node_visitor
class TransposeVisitor(NodeVisitor):
- target = "aten.permute_copy.default"
+ target = ["aten.permute_copy.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_unsqueeze.py b/backends/qualcomm/builders/op_unsqueeze.py
index 636dc94e84..1a94903291 100644
--- a/backends/qualcomm/builders/op_unsqueeze.py
+++ b/backends/qualcomm/builders/op_unsqueeze.py
@@ -15,7 +15,7 @@
@register_node_visitor
class Unsqueeze(NodeVisitor):
- target = "aten.unsqueeze_copy.default"
+ target = ["aten.unsqueeze_copy.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/builders/op_upsample_bilinear2d.py b/backends/qualcomm/builders/op_upsample_bilinear2d.py
index f32f136aa1..b383693ead 100644
--- a/backends/qualcomm/builders/op_upsample_bilinear2d.py
+++ b/backends/qualcomm/builders/op_upsample_bilinear2d.py
@@ -15,7 +15,7 @@
@register_node_visitor
class ResizeBilinear(NodeVisitor):
- target = "aten.upsample_bilinear2d.default"
+ target = ["aten.upsample_bilinear2d.default"]
def __init__(self, *args) -> None:
super().__init__(*args)
diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py
index ee7d6a7a3b..585711c378 100644
--- a/backends/qualcomm/tests/utils.py
+++ b/backends/qualcomm/tests/utils.py
@@ -201,7 +201,8 @@ def lower_module_and_test_output(
# Assert the backend name is qnn
self.assertEqual(
- len(exec_prog.program.execution_plan[0].delegates), expected_partitions
+ len(exec_prog.program.execution_plan[0].delegates),
+ expected_partitions,
)
for i in range(expected_partitions):
self.assertEqual(
diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py
index 7fa696efba..b5c5d4dfed 100644
--- a/backends/qualcomm/utils/utils.py
+++ b/backends/qualcomm/utils/utils.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-from typing import List, Tuple
+from typing import Callable, Dict, List, Tuple
import executorch.exir as exir
@@ -19,8 +19,6 @@
ConvertBinaryOpsWithScalar,
)
from executorch.backends.qualcomm.passes.convert_bmm_to_matmul import ConvertBmmToMatmul
-from executorch.backends.qualcomm.passes.convert_hardsigmoid import ConvertHardsigmoid
-from executorch.backends.qualcomm.passes.convert_hardswish import ConvertHardswish
from executorch.backends.qualcomm.passes.convert_interpolate_with_upsample2d import (
ConvertInterpolateWithUpsample2D,
)
@@ -29,9 +27,6 @@
from executorch.backends.qualcomm.passes.i64_to_i32 import I64toI32
from executorch.backends.qualcomm.passes.insert_requantize import InsertRequantize
from executorch.backends.qualcomm.passes.layout_transform import LayoutTransform
-from executorch.backends.qualcomm.passes.recompose_pixel_shuffle import (
- RecomposePixelShuffle,
-)
from executorch.backends.qualcomm.passes.remove_clone import RemoveClone
from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import (
_soc_info_table,
@@ -49,7 +44,9 @@
convert_to_flatbuffer,
convert_to_option,
)
+from executorch.exir import ExirExportedProgram
from executorch.exir.backend.compile_spec_schema import CompileSpec
+from torch._decomp import core_aten_decompositions as torch_core_aten_decompositions
from torch.export.exported_program import ExportedProgram
from torch.fx import passes
@@ -86,16 +83,27 @@ def canonicalize_program(prog: ExportedProgram):
)
+def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]:
+ source_decompositions = torch_core_aten_decompositions()
+ # The below super ops are supported by QNN
+ remove_decompositions = [
+ torch.ops.aten.pixel_shuffle.default,
+ torch.ops.aten.hardswish.default,
+ ]
+
+ for key in remove_decompositions:
+ source_decompositions.pop(key)
+
+ return source_decompositions
+
+
def _transform(edge_program: ExportedProgram) -> None:
# currently ExirExportedProgram.transform does not accept
# changes of input number which was caused by FoldQDQ
# apply passes one by one here to avoid IR capture failure
graph_module = edge_program.graph_module
RemoveClone()(graph_module)
- RecomposePixelShuffle()(graph_module)
ConvertToLinear()(graph_module)
- ConvertHardsigmoid()(graph_module)
- ConvertHardswish()(graph_module)
ConvertBmmToMatmul()(graph_module)
ConvertInterpolateWithUpsample2D()(graph_module)
I64toI32(edge_program)(graph_module)
@@ -111,19 +119,18 @@ def capture_program(
module: torch.nn.Module,
inputs: Tuple[torch.Tensor],
) -> exir.ExirExportedProgram:
- # TODO: should switch to torch.export.export & custom deomposition
- # to reduce maintaining effort.
- exir_exported_program = exir.capture(
- module,
- inputs,
- qnn_capture_config(),
- )
+ ep = torch.export.export(module, inputs)
+ decomposed_ep = ep.run_decompositions(get_decomp_table())
+
# We choose call_operator by target in ConvertBinaryOpsWithScalar
# because it is the same source_fn_stack for MultiheadAttention
- exir_exported_program.transform(ConvertBinaryOpsWithScalar())
- ex_prog = exir_exported_program.to_edge(qnn_edge_config())
- _transform(ex_prog.exported_program)
- return ex_prog
+ # TODO: Should modify the scalar op in the op builder instead of
+ # using transformation
+ core_ep = ExirExportedProgram(decomposed_ep, False)
+ core_ep.transform(ConvertBinaryOpsWithScalar())
+ edge_ep = core_ep.to_edge(qnn_edge_config())
+ _transform(edge_ep.exported_program)
+ return edge_ep
def draw_graph(title, path, graph_module: torch.fx.GraphModule):
diff --git a/backends/vulkan/CMakeLists.txt b/backends/vulkan/CMakeLists.txt
index 605d1a5029..cee4b9da56 100644
--- a/backends/vulkan/CMakeLists.txt
+++ b/backends/vulkan/CMakeLists.txt
@@ -20,12 +20,14 @@ if(NOT EXECUTORCH_ROOT)
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
endif()
+include(${EXECUTORCH_ROOT}/build/Utils.cmake)
+
if(NOT RUNTIME_PATH)
set(RUNTIME_PATH ${CMAKE_CURRENT_SOURCE_DIR}/runtime)
endif()
if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
+ resolve_python_executable()
endif()
if(NOT FLATC_EXECUTABLE)
diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
index acb672ae1b..c648db2c4c 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
@@ -36,7 +36,12 @@ layout(set = 0, binding = 5) uniform PRECISION restrict OtherSizes {
}
other_sizes;
-layout(set = 0, binding = 6) uniform PRECISION restrict Alpha {
+layout(set = 0, binding = 6) uniform PRECISION restrict BroadcastParams {
+ ivec2 data;
+}
+broadcast_params;
+
+layout(set = 0, binding = 7) uniform PRECISION restrict Alpha {
float data;
}
alpha;
@@ -63,8 +68,11 @@ void main() {
COORD_TO_POS_${PACKING}(other_coord, other_sizes.data),
0));
- // Detect broadcasting
- if (PACKED_DIM_${PACKING}(other_sizes.data) < PACKED_DIM_${PACKING}(in_sizes.data)) {
+ // Check boolean broadcast flags; we use ivec2 instead of bvec2 for alignment.
+ if (broadcast_params.data.x > 0) {
+ in_texel = in_texel.xxxx;
+ }
+ if (broadcast_params.data.y > 0) {
other_texel = other_texel.xxxx;
}
diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
index b8adb75c4f..fef2802f2d 100644
--- a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
@@ -24,7 +24,6 @@ void check_binary_op_args(
const vTensor& other,
const vTensor& out) {
VK_CHECK_COND(check_same_memory_layout(self, other, out));
- VK_CHECK_COND(check_broadcastable(self, other));
std::vector broadcasted_sizes =
calculate_broadcasted_output_size(self, other);
VK_CHECK_COND(out.sizes() == broadcasted_sizes);
@@ -36,6 +35,8 @@ void resize_binary_op_node(
const std::vector& extra_args) {
(void)extra_args;
vTensor& out = graph->get_val(args[0].refs[0]).toTensor();
+
+ // TODO(T183442143): Verify tensors are broadcastable.
vTensor& self = graph->get_val(args[1].refs[0]).toTensor();
vTensor& other = graph->get_val(args[1].refs[1]).toTensor();
@@ -73,6 +74,9 @@ void add_binary_op_node(
alpha_val = extract_scalar(graph.get_val(alpha));
}
+ const api::utils::ivec2 broadcast_params =
+ create_broadcast_params(t_in1, t_in2);
+
std::stringstream kernel_name;
kernel_name << "binary_" << op_name;
apply_memory_layout_suffix(kernel_name, t_out);
@@ -90,6 +94,7 @@ void add_binary_op_node(
{t_out.gpu_sizes_ubo(),
t_in1.gpu_sizes_ubo(),
t_in2.gpu_sizes_ubo(),
+ graph.create_params_buffer(broadcast_params),
graph.create_params_buffer(alpha_val)},
// Resizing
resize_binary_op_node));
diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp
index e8f8e7b0bd..96544743f4 100644
--- a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp
@@ -72,28 +72,35 @@ bool check_same_memory_layout(
return (t1.gpu_memory_layout() == t3.gpu_memory_layout());
}
-bool check_broadcastable(const vTensor& t1, const vTensor& t2) {
- size_t ndim = std::max(t1.sizes().size(), t2.sizes().size());
+//
+// Broadcast flag functions
+//
- // Match the sizes in reverse because sizes are in NCHW order
- for (int i = -1; i >= -ndim; --i) {
- int64_t t1_size = api::utils::val_at(i, t1.sizes());
- int64_t t2_size = api::utils::val_at(i, t2.sizes());
- // If the sizes are not equal, one of them must be 1
- if (t1_size != t2_size) {
- if (t1_size > 1 && t2_size != 1) {
- return false;
- } else if (t2_size > 1 && t1_size != 1) {
- return false;
- }
- }
+bool is_packed_dim_broadcasted(const vTensor& sndr, const vTensor& rcvr) {
+ // We assume that the tensors are broadcastable. If values aren't equal at
+ // some index, then the value of rcvr is 1 and hence should be broadcasted.
+ switch (sndr.gpu_memory_layout()) {
+ case api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED:
+ return api::utils::val_at(-3, sndr.sizes()) >
+ api::utils::val_at(-3, rcvr.sizes());
+ case api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED:
+ return api::utils::val_at(-2, sndr.sizes()) >
+ api::utils::val_at(-2, rcvr.sizes());
+ case api::GPUMemoryLayout::TENSOR_WIDTH_PACKED:
+ return api::utils::val_at(-1, sndr.sizes()) >
+ api::utils::val_at(-1, rcvr.sizes());
}
+}
- return true;
+api::utils::ivec2 create_broadcast_params(
+ const vTensor& t1,
+ const vTensor& t2) {
+ return api::utils::make_ivec2(
+ {is_packed_dim_broadcasted(t2, t1), is_packed_dim_broadcasted(t1, t2)});
}
//
-// Work Group Size Calculation Utilities
+// Work group size calculation functions
//
api::utils::uvec3 adaptive_work_group_size(
diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h
index aa206fb1cf..31cef9f18a 100644
--- a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h
+++ b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h
@@ -47,10 +47,14 @@ bool check_same_memory_layout(
const vTensor& t2,
const vTensor& t3);
-bool check_broadcastable(const vTensor& t1, const vTensor& t2);
+//
+// Broadcast flag functions
+//
+
+api::utils::ivec2 create_broadcast_params(const vTensor& t1, const vTensor& t2);
//
-// Work Group Size Calculation Utilities
+// Work group size calculation functions
//
api::utils::uvec3 adaptive_work_group_size(
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index f5c838314d..8481fdf945 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -146,16 +146,20 @@ class AddModule(torch.nn.Module):
def __init__(self):
super().__init__()
- def forward(self, x, y):
+ def forward(self, x, y, w):
z = x + y
z = z + x
z = z + x
+ z = z + w
+ z = w + z
+ z = z + 3 # test scalar broadcasting
return z
add_module = AddModule()
sample_inputs = (
torch.rand(size=(2, 3), dtype=torch.float32),
torch.rand(size=(2, 3), dtype=torch.float32),
+ torch.rand(size=(2, 1), dtype=torch.float32), # test broadcasting
)
self.lower_module_and_test_output(add_module, sample_inputs)
diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp
index 2b7766665c..e2810d7c06 100644
--- a/backends/vulkan/test/vulkan_compute_api_test.cpp
+++ b/backends/vulkan/test/vulkan_compute_api_test.cpp
@@ -549,7 +549,7 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
std::vector size_big = {12, 64, 64};
std::vector size_small = {12, 64, 64};
- // Build graph
+ // Build graph and regularly check allocation counts
IOValueRef a = graph.add_input_tensor(
size_big,
@@ -560,9 +560,8 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
api::kFloat,
/*shared_object_idx = */ 4);
- // Allocation count will be 6:
- // 4: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for each staging shader
- // 2: staging buffer for each input tensor
+ // +4: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for each staging shader
+ // +2: staging buffer for each input tensor
EXPECT_TRUE(get_vma_allocation_count() == 6);
ValueRef c = graph.add_tensor(
@@ -578,11 +577,10 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
api::kFloat,
/*shared_object_idx = */ 2);
- // Allocation count will be 11, 5 are new:
- // 2: out.gpu_sizes_ubo(), alpha UBO for arithmetic shader
- // 2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() uniform buffer for staging shader
- // 1: staging buffer for the input tensor
- EXPECT_TRUE(get_vma_allocation_count() == 11);
+ // +3: out.gpu_sizes_ubo(), alpha UBO, broadcast UBO for arithmetic shader
+ // +2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() uniform buffer for staging shader
+ // +1: staging buffer for the input tensor
+ EXPECT_TRUE(get_vma_allocation_count() == 12);
ValueRef e = graph.add_tensor(
size_big,
@@ -596,18 +594,16 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
out.value = e;
out.staging = graph.set_output_tensor(out.value);
- // Allocation count will be 15, 4 are new:
- // 1: alpha UBO for arithmetic shader
- // 2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for staging shader
- // 1 staging buffer for the input tensor
- EXPECT_TRUE(get_vma_allocation_count() == 15);
+ // +2: alpha UBO, broadcast UBO for arithmetic shader
+ // +2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for staging shader
+ // +1 staging buffer for the input tensor
+ EXPECT_TRUE(get_vma_allocation_count() == 17);
graph.prepare();
graph.encode_execute();
- // Allocation count will be 18, 3 are new:
- // 3: shared memory allocations for tensors
- EXPECT_TRUE(get_vma_allocation_count() == 18);
+ // +3: shared memory allocations for tensors
+ EXPECT_TRUE(get_vma_allocation_count() == 20);
// Run graph
diff --git a/backends/xnnpack/CMakeLists.txt b/backends/xnnpack/CMakeLists.txt
index ea1ed0b73f..95aea7f93c 100644
--- a/backends/xnnpack/CMakeLists.txt
+++ b/backends/xnnpack/CMakeLists.txt
@@ -17,10 +17,6 @@ if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
endif()
-if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
-endif()
-
if(NOT FLATC_EXECUTABLE)
set(FLATC_EXECUTABLE flatc)
endif()
@@ -32,6 +28,10 @@ endif()
include(${EXECUTORCH_ROOT}/build/Utils.cmake)
+if(NOT PYTHON_EXECUTABLE)
+ resolve_python_executable()
+endif()
+
set(_common_include_directories ${EXECUTORCH_ROOT}/..)
set(_common_compile_options -Wno-deprecated-declarations)
diff --git a/backends/xnnpack/operators/TARGETS b/backends/xnnpack/operators/TARGETS
index b9f2998c8f..1fd4b5c475 100644
--- a/backends/xnnpack/operators/TARGETS
+++ b/backends/xnnpack/operators/TARGETS
@@ -13,6 +13,5 @@ runtime.python_library(
"//executorch/backends/xnnpack/utils:xnnpack_utils",
"//executorch/exir:graph_module",
"//executorch/exir/backend:backend_details",
- "//executorch/extension/aot_util:aot_util",
],
)
diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py
index f70afa754d..0b0eb7912a 100644
--- a/backends/xnnpack/operators/node_visitor.py
+++ b/backends/xnnpack/operators/node_visitor.py
@@ -5,9 +5,7 @@
# LICENSE file in the root directory of this source tree.
import ctypes
-import sys
-from pathlib import Path
from typing import cast, Dict, List, Optional, Tuple
import torch
@@ -449,18 +447,6 @@ def define_tensor(
if quant_params is not None:
vals_to_ids[quant_params.q_input] = id_out
- @staticmethod
- def find_aot_util_path() -> str:
- # Look for .so installed by wheel (OSS). TODO(gjcomer) Improve this.
- rel_path = "executorch/extension/aot_util/libaot_util.so"
- for sys_path in sys.path:
- so_path = Path(sys_path) / rel_path
- if so_path.exists():
- return str(so_path.absolute().as_posix())
-
- # Fall back to buck.
- return "//executorch/extension/aot_util:aot_util"
-
@staticmethod
def convert_to_qc4w(inp: torch.Tensor) -> torch.Tensor:
"""
@@ -478,37 +464,24 @@ def convert_to_qc4w(inp: torch.Tensor) -> torch.Tensor:
# Assuming we have a 2d tensor
if inp.ndim != 2:
inp = inp.squeeze()
- assert (
- inp.ndim == 2
- ), f"convert_to_qc4w: expecting input tensor to be 2d, got {inp.ndim}"
- oc, ic = inp.shape
+ assert (
+ inp.ndim == 2
+ ), f"convert_to_qc4w: expecting input tensor to be 2d, got {inp.ndim}"
# pad ic
- if ic % 2 != 0:
+ if inp.shape[-1] % 2 != 0:
inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0)
+ # Shape after padding
+ oc, ic = inp.shape
+ assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even"
+
# Adjust inp tensor for zp
inp = inp.to(dtype=torch.uint8) + 8
- # prepare result tensor
- ric = int((ic + 1) / 2)
- result = torch.zeros([oc, ric], dtype=torch.uint8)
-
- try:
- aot_path = NodeVisitor.find_aot_util_path()
- torch.ops.load_library(aot_path)
- result = torch.ops.xnnpack.convert_to_qc4w(inp)
- except:
- # Fallback to python implementation
- # TODO Warn the user? They might be developing in-tree and didn't install,
- # in which case, this will be very slow for large models.
- for o in range(oc):
- for i in range(ric):
- j = 2 * i
- result[o][i] = inp[o][j]
- result[o][i] += inp[o][j + 1] << 4
-
- return result
+ # Prepare the Result tensor
+ inp = inp.contiguous().view(-1)
+ return (inp[1::2] << 4 | inp[::2]).view(oc, int(ic / 2))
def get_serialized_buffer_index(
self,
diff --git a/backends/xnnpack/test/TARGETS b/backends/xnnpack/test/TARGETS
index 3b9950a7a1..ae16d366b6 100644
--- a/backends/xnnpack/test/TARGETS
+++ b/backends/xnnpack/test/TARGETS
@@ -75,15 +75,3 @@ runtime.python_test(
"//executorch/backends/xnnpack:xnnpack_preprocess",
],
)
-
-runtime.python_test(
- name = "test_custom_convert_qc4w_op",
- srcs = ["ops/test_custom_convert_to_qc4w.py"],
- deps = [
- "//caffe2:torch",
- "//executorch/extension/aot_util:aot_util",
- ],
- external_deps = [
- "libtorch",
- ],
-)
diff --git a/backends/xnnpack/test/ops/add.py b/backends/xnnpack/test/ops/add.py
index 8883540e27..3a56e0f4c6 100644
--- a/backends/xnnpack/test/ops/add.py
+++ b/backends/xnnpack/test/ops/add.py
@@ -97,7 +97,6 @@ def test_qs8_add_constant(self):
.check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
.to_executorch()
.serialize()
- .dump_artifact("/data/users/maxren/models/q_add_constant.pte")
.run_method()
.compare_outputs()
)
diff --git a/backends/xnnpack/test/ops/test_custom_convert_to_qc4w.py b/backends/xnnpack/test/ops/test_custom_convert_to_qc4w.py
deleted file mode 100644
index c6e13ad345..0000000000
--- a/backends/xnnpack/test/ops/test_custom_convert_to_qc4w.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import unittest
-
-import torch
-
-
-class TestCustomQC4WConvert(unittest.TestCase):
- def setUp(self):
- torch.ops.load_library("//executorch/extension/aot_util:aot_util")
-
- def test_convert(self):
- def _ref_output(inp):
- oc, ic = inp.shape
- if ic % 2 != 0:
- raise ValueError("Number of input channels not divisible by 2.")
- ric = (ic + 1) // 2
- result = torch.zeros([oc, ric], dtype=torch.uint8)
- for o in range(oc):
- for i in range(ric):
- j = 2 * i
- result[o][i] = inp[o][j]
- result[o][i] += inp[o][j + 1] << 4
- return result
-
- inp = torch.randint(low=0, high=15, size=(20, 42), dtype=torch.uint8)
- result = torch.ops.xnnpack.convert_to_qc4w(inp)
- ref_result = _ref_output(inp)
- assert torch.equal(result, ref_result), "Outputs dont match"
-
- def test_convert_throws(self):
- inp = torch.randint(low=0, high=15, size=(20, 41), dtype=torch.uint8)
- exception_thrown = False
- # Because for some reason self.assertRaises does not work
- # and didnt try to debug
- try:
- torch.ops.xnnpack.convert_to_qc4w(inp)
- except:
- exception_thrown = True
- self.assertTrue(exception_thrown)
-
- inp = torch.rand((20, 41))
- exception_thrown = False
- # Because for some reason self.assertRaises does not work
- # and didnt try to debug
- try:
- torch.ops.xnnpack.convert_to_qc4w(inp)
- except:
- exception_thrown = True
- self.assertTrue(exception_thrown)
diff --git a/build/Utils.cmake b/build/Utils.cmake
index 10844b453c..3c89d2d577 100644
--- a/build/Utils.cmake
+++ b/build/Utils.cmake
@@ -46,8 +46,6 @@ function(executorch_print_configuration_summary)
" EXECUTORCH_BUILD_COREML : ${EXECUTORCH_BUILD_COREML}")
message(STATUS " EXECUTORCH_BUILD_EXECUTOR_RUNNER : "
"${EXECUTORCH_BUILD_EXECUTOR_RUNNER}")
- message(STATUS " EXECUTORCH_BUILD_EXTENSION_AOT_UTIL : "
- "${EXECUTORCH_BUILD_EXTENSION_AOT_UTIL}")
message(STATUS " EXECUTORCH_BUILD_EXTENSION_DATA_LOADER : "
"${EXECUTORCH_BUILD_EXTENSION_DATA_LOADER}")
message(STATUS " EXECUTORCH_BUILD_EXTENSION_MODULE : "
@@ -191,10 +189,32 @@ function(resolve_buck2)
if(resolve_buck2_exit_code EQUAL 0)
set(BUCK2 ${resolve_buck2_output} PARENT_SCOPE)
message(STATUS "Resolved buck2 as ${resolve_buck2_output}.")
- else()
+ elseif(resolve_buck2_exit_code EQUAL 2)
# Wrong buck version used. Stop here to ensure that the user sees
# the error.
- message(FATAL_ERROR "Failed to resolve buck2.")
- message(FATAL_ERROR ${resolve_buck2_error})
+ message(FATAL_ERROR "Failed to resolve buck2.\n${resolve_buck2_error}")
+ else()
+ # Unexpected failure of the script. Warn.
+ message(WARNING "Failed to resolve buck2.")
+ message(WARNING "${resolve_buck2_error}")
+
+ if("${BUCK2}" STREQUAL "")
+ set(BUCK2 "buck2" PARENT_SCOPE)
+ endif()
endif()
-endfunction()
\ No newline at end of file
+endfunction()
+
+# Sets the value of the PYTHON_EXECUTABLE variable to 'python' if in
+# an active (non-base) conda environment, and 'python3' otherwise. This
+# maintains backwards compatibility for non-conda users and avoids conda
+# users needing to explicitly set PYTHON_EXECUTABLE=python.
+function(resolve_python_executable)
+ # Counter-intuitively, CONDA_DEFAULT_ENV contains the name of the
+ # active environment.
+ if(DEFINED ENV{CONDA_DEFAULT_ENV} AND
+ NOT $ENV{CONDA_DEFAULT_ENV} STREQUAL "base")
+ set(PYTHON_EXECUTABLE python PARENT_SCOPE)
+ else()
+ set(PYTHON_EXECUTABLE python3 PARENT_SCOPE)
+ endif()
+endfunction()
diff --git a/build/build_apple_frameworks.sh b/build/build_apple_frameworks.sh
index 61c5e00e93..d40e231f14 100755
--- a/build/build_apple_frameworks.sh
+++ b/build/build_apple_frameworks.sh
@@ -151,7 +151,8 @@ mkdir -p "$HEADERS_PATH"
//extension/module: \
| rsync -av --files-from=- "$SOURCE_ROOT_DIR" "$HEADERS_PATH/executorch"
-cp "$SOURCE_ROOT_DIR/extension/apple/ExecuTorch/Exported/"{*.h,*.modulemap} "$HEADERS_PATH"
+cp "$SOURCE_ROOT_DIR/extension/apple/ExecuTorch/Exported/"*.h "$HEADERS_PATH/executorch"
+cp "$SOURCE_ROOT_DIR/extension/apple/ExecuTorch/Exported/"*.modulemap "$HEADERS_PATH"
echo "Creating frameworks"
diff --git a/build/cmake_deps.toml b/build/cmake_deps.toml
index 31342240c3..49e543f99b 100644
--- a/build/cmake_deps.toml
+++ b/build/cmake_deps.toml
@@ -89,17 +89,6 @@ filters = [
# ---------------------------------- core end ----------------------------------
# ---------------------------------- extension start ----------------------------------
-
-[targets.extension_aot_util]
-buck_targets = [
- "//extension/aot_util:aot_util",
-]
-filters = [
- ".cpp$",
-]
-deps = [
- "executorch",
-]
[targets.extension_data_loader]
buck_targets = [
"//extension/data_loader:buffer_data_loader",
diff --git a/build/resolve_buck.py b/build/resolve_buck.py
index 6cd816494d..cba151ab34 100644
--- a/build/resolve_buck.py
+++ b/build/resolve_buck.py
@@ -154,8 +154,10 @@ def resolve_buck2(args: argparse.Namespace) -> Union[str, int]:
)
# Return an error, since the build will fail later. This lets us
- # give the user a more useful error message.
- return -1
+ # give the user a more useful error message. Note that an exit
+ # code of 2 allows us to distinguish from an unexpected error,
+ # such as a failed import, which exits with 1.
+ return 2
else:
# Look for system buck2 and check version. Note that this can return
# None.
diff --git a/configurations/CMakeLists.txt b/configurations/CMakeLists.txt
index 57f7abe31d..212a310bd7 100644
--- a/configurations/CMakeLists.txt
+++ b/configurations/CMakeLists.txt
@@ -11,9 +11,6 @@ if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
endif()
-if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
-endif()
# Source root directory for executorch.
if(NOT EXECUTORCH_ROOT)
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..)
@@ -23,6 +20,12 @@ if(NOT TORCH_ROOT)
set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch)
endif()
+include(${EXECUTORCH_ROOT}/build/Utils.cmake)
+
+if(NOT PYTHON_EXECUTABLE)
+ resolve_python_executable()
+endif()
+
set(_common_compile_options -Wno-deprecated-declarations)
include(${EXECUTORCH_ROOT}/build/Utils.cmake)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index c9ec76538d..239319f7c2 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -25,9 +25,8 @@
import pytorch_sphinx_theme
+# To let us import ./custom_directives.py
sys.path.insert(0, os.path.abspath("."))
-sys.path.insert(0, os.path.abspath("../../.."))
-sys.path.insert(0, os.path.abspath("../.."))
# -- Project information -----------------------------------------------------
project = "ExecuTorch"
@@ -44,8 +43,6 @@
import os
import sys
-sys.path.insert(0, os.path.abspath("../../"))
-
extensions = [
"breathe",
"sphinx.ext.autodoc",
diff --git a/docs/source/export-to-executorch-api-reference.rst b/docs/source/export-to-executorch-api-reference.rst
index 241228fafa..2150ac7f8c 100644
--- a/docs/source/export-to-executorch-api-reference.rst
+++ b/docs/source/export-to-executorch-api-reference.rst
@@ -1,7 +1,7 @@
Export to ExecuTorch API Reference
----------------------------------
-.. automodule:: exir
+.. automodule:: executorch.exir
.. autofunction:: to_edge
.. autoclass:: EdgeProgramManager
@@ -10,7 +10,7 @@ Export to ExecuTorch API Reference
.. autoclass:: ExecutorchProgramManager
:members: methods, config_methods, exported_program, buffer, debug_handle_map, dump_executorch_program
-.. automodule:: exir.backend.backend_api
+.. automodule:: executorch.exir.backend.backend_api
.. autofunction:: to_backend
.. autoclass:: LoweredBackendModule
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 871f4aba87..adbda475aa 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -85,10 +85,7 @@ Topics in this section will help you get started with ExecuTorch.
:caption: Working with LLMs
:hidden:
- llm/introduction
- llm/mobile/index
- llm/desktop/index
- llm/advanced-flow/index
+ llm/getting-started
.. toctree::
:glob:
diff --git a/docs/source/llm/advanced-flow/advanced-flow.md b/docs/source/llm/advanced-flow/advanced-flow.md
deleted file mode 100644
index f8b596340f..0000000000
--- a/docs/source/llm/advanced-flow/advanced-flow.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# Advanced Flows
-
-## Custom quantization
-
-## Bring GGUF to PyTorch ecosystem
-
-## TorchTune interoperability
diff --git a/docs/source/llm/advanced-flow/index.rst b/docs/source/llm/advanced-flow/index.rst
deleted file mode 100644
index 03e2f1f0f4..0000000000
--- a/docs/source/llm/advanced-flow/index.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-Enabling LLMs on Advanced Flow
-==============================
-
-This section will walk you through
-
-.. toctree::
- :maxdepth: 1
-
- advanced-flow
diff --git a/docs/source/llm/desktop/benchmarks.md b/docs/source/llm/desktop/benchmarks.md
deleted file mode 100644
index 9a827aca07..0000000000
--- a/docs/source/llm/desktop/benchmarks.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# Enabling LLMs on Desktop
-
-## Local Llama on Desktop Benchmarks
-
-**Results**
-
-**Instructions**
diff --git a/docs/source/llm/desktop/index.rst b/docs/source/llm/desktop/index.rst
deleted file mode 100644
index c9078231d9..0000000000
--- a/docs/source/llm/desktop/index.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-Enabling LLMs on Desktop
-========================
-
-This section will walk you through
-
-.. toctree::
- :maxdepth: 1
-
- benchmarks
diff --git a/docs/source/llm/getting-started.md b/docs/source/llm/getting-started.md
new file mode 100644
index 0000000000..2bd5d5e851
--- /dev/null
+++ b/docs/source/llm/getting-started.md
@@ -0,0 +1,34 @@
+# Getting Started with LLMs via ExecuTorch
+
+This section provides guidance on enabling Large Language Models (LLMs), starting with a simple example and gradually introducing new concepts to improve performance and productivity.
+
+## Prerequisites
+
+- To run this tutorial, you’ll first need to first [Set up your ExecuTorch environment](../getting-started-setup.md).
+
+- We highly suggest you to check out [LLama2 README](../../../examples/models/llama2/README.md) in our examples for end-to-end Llama2 mobile demo.
+
+
+## Simple “Hello World” LLM example
+
+Let's create a simple LLM app from scratch. TODO
+
+## Quantization
+
+Most LLMs are too large to fit into a mobile phone, making quantization necessary. In this example, we will demonstrate how to use the XNNPACKQuantizer to quantize the model and run it on a CPU. TODO
+
+## Use Mobile Acceleration
+
+One of the benefits of ExecuTorch is the ability to delegate to mobile accelerators. Now, we will show a few examples of how to easily take advantage of mobile accelerators. TODO
+
+## Debugging and Profiling
+
+It is sometimes necessary to profile and inspect the execution process. In this example, we will demonstrate how the ExecuTorch SDK can be used to identify which operations are being executed on which hardware. TODO
+
+## How to use custom kernels
+
+In some cases, it is necessary to write custom kernels or import them from another source in order to achieve the desired performance. In this example, we will demonstrate how to use the `kvcache_with_sdpa` kernel.
+
+## How to build Mobile Apps
+
+Here's how to finally build a mobile app on Android and iOS. TODO
diff --git a/docs/source/llm/introduction.md b/docs/source/llm/introduction.md
deleted file mode 100644
index a97310eb4a..0000000000
--- a/docs/source/llm/introduction.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# Introduction
-
-## Current landscape of local LLMs
-
-## What is our offering?
-
-## Why and when should you use it?
diff --git a/docs/source/llm/mobile/benchmarks.md b/docs/source/llm/mobile/benchmarks.md
deleted file mode 100644
index e056d59d95..0000000000
--- a/docs/source/llm/mobile/benchmarks.md
+++ /dev/null
@@ -1,6 +0,0 @@
-# Mobile Benchmarks for Local Llama
-
-## Results
-
-
-## Instructions
diff --git a/docs/source/llm/mobile/customization-examples.md b/docs/source/llm/mobile/customization-examples.md
deleted file mode 100644
index 1364304f8d..0000000000
--- a/docs/source/llm/mobile/customization-examples.md
+++ /dev/null
@@ -1,9 +0,0 @@
-# Customization examples
-
-## Custom tokenization
-
-## Custom sampler
-
-## Speculative decoding
-
-## Modify a mobile app to use a different LLM model
diff --git a/docs/source/llm/mobile/getting-started.md b/docs/source/llm/mobile/getting-started.md
deleted file mode 100644
index c3d0b11d6d..0000000000
--- a/docs/source/llm/mobile/getting-started.md
+++ /dev/null
@@ -1,16 +0,0 @@
-# Getting Started with LLMs via ExecuTorch
-
-
-## Simple “Hello World” example
-
-
-## Use Mobile Acceleration
-
-
-## Quantization via XNNPACKQuantizer
-
-
-## Debugging and Profiling
-
-
-## Build Mobile LLM chat App Examples
diff --git a/docs/source/llm/mobile/index.rst b/docs/source/llm/mobile/index.rst
deleted file mode 100644
index 8012c01f8c..0000000000
--- a/docs/source/llm/mobile/index.rst
+++ /dev/null
@@ -1,12 +0,0 @@
-Enabling LLMs on Mobile
-=======================
-
-This section will walk you through
-
-.. toctree::
- :maxdepth: 1
-
- benchmarks
- getting-started
- customization-examples
- validating-other-models
diff --git a/docs/source/llm/mobile/validating-other-models.md b/docs/source/llm/mobile/validating-other-models.md
deleted file mode 100644
index 6097241777..0000000000
--- a/docs/source/llm/mobile/validating-other-models.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Validating other models
-
-## Exportability results
diff --git a/docs/source/sdk-bundled-io.md b/docs/source/sdk-bundled-io.md
index d324e14ce3..2ed256d2ae 100644
--- a/docs/source/sdk-bundled-io.md
+++ b/docs/source/sdk-bundled-io.md
@@ -26,7 +26,7 @@ In `BundledProgram`, we create two new classes, `MethodTestCase` and `MethodTest
:::{dropdown} `MethodTestCase`
```{eval-rst}
-.. autofunction:: bundled_program.config.MethodTestCase.__init__
+.. autofunction:: executorch.sdk.bundled_program.config.MethodTestCase.__init__
:noindex:
```
:::
@@ -34,7 +34,7 @@ In `BundledProgram`, we create two new classes, `MethodTestCase` and `MethodTest
:::{dropdown} `MethodTestSuite`
```{eval-rst}
-.. autofunction:: bundled_program.config.MethodTestSuite
+.. autofunction:: executorch.sdk.bundled_program.config.MethodTestSuite
:noindex:
```
:::
@@ -49,7 +49,7 @@ We provide `create_bundled_program` API under `executorch/sdk/bundled_program/co
:::{dropdown} `BundledProgram`
```{eval-rst}
-.. currentmodule:: bundled_program.core
+.. currentmodule:: executorch.sdk.bundled_program.core
.. autofunction:: create_bundled_program
:noindex:
```
@@ -66,13 +66,13 @@ To serialize `BundledProgram` to make runtime APIs use it, we provide two APIs,
:::{dropdown} Serialize and Deserialize
```{eval-rst}
-.. currentmodule:: bundled_program.serialize
+.. currentmodule:: executorch.sdk.bundled_program.serialize
.. autofunction:: serialize_from_bundled_program_to_flatbuffer
:noindex:
```
```{eval-rst}
-.. currentmodule:: bundled_program.serialize
+.. currentmodule:: executorch.sdk.bundled_program.serialize
.. autofunction:: deserialize_from_flatbuffer_to_bundled_program
:noindex:
```
diff --git a/docs/source/sdk-etrecord.rst b/docs/source/sdk-etrecord.rst
index 8ea6293fbd..e9eeb52b4f 100644
--- a/docs/source/sdk-etrecord.rst
+++ b/docs/source/sdk-etrecord.rst
@@ -31,7 +31,7 @@ they are interested in working with via our tooling.
.. warning::
Users should do a deepcopy of the output of to_edge() and pass in the deepcopy to the generate_etrecord API. This is needed because the subsequent call, to_executorch(), does an in-place mutation and will lose debug data in the process.
-.. currentmodule:: sdk.etrecord._etrecord
+.. currentmodule:: executorch.sdk.etrecord._etrecord
.. autofunction:: generate_etrecord
Using an ``ETRecord``
diff --git a/docs/source/sdk-inspector.rst b/docs/source/sdk-inspector.rst
index 0e85f0b5a9..23c529cb9d 100644
--- a/docs/source/sdk-inspector.rst
+++ b/docs/source/sdk-inspector.rst
@@ -26,7 +26,7 @@ Inspector Methods
Constructor
~~~~~~~~~~~
-.. autofunction:: sdk.Inspector.__init__
+.. autofunction:: executorch.sdk.Inspector.__init__
**Example Usage:**
@@ -39,13 +39,13 @@ Constructor
to_dataframe
~~~~~~~~~~~~~~~~
-.. autofunction:: sdk.Inspector.to_dataframe
+.. autofunction:: executorch.sdk.Inspector.to_dataframe
print_data_tabular
~~~~~~~~~~~~~~~~~~
-.. autofunction:: sdk.Inspector.print_data_tabular
+.. autofunction:: executorch.sdk.Inspector.print_data_tabular
.. _example-usage-1:
@@ -61,7 +61,7 @@ print_data_tabular
find_total_for_module
~~~~~~~~~~~~~~~~~~~~~
-.. autofunction:: sdk.Inspector.find_total_for_module
+.. autofunction:: executorch.sdk.Inspector.find_total_for_module
.. _example-usage-2:
@@ -79,7 +79,7 @@ find_total_for_module
get_exported_program
~~~~~~~~~~~~~~~~~~~~
-.. autofunction:: sdk.Inspector.get_exported_program
+.. autofunction:: executorch.sdk.Inspector.get_exported_program
.. _example-usage-3:
@@ -118,7 +118,7 @@ of an ``Inspector`` instance, for example:
inspector.event_blocks
-.. autoclass:: sdk.inspector.inspector.EventBlock
+.. autoclass:: executorch.sdk.inspector.EventBlock
``Event`` Class
~~~~~~~~~~~~~~~
@@ -126,7 +126,7 @@ of an ``Inspector`` instance, for example:
Access ``Event`` instances through the ``events`` attribute of an
``EventBlock`` instance.
-.. autoclass:: sdk.inspector.inspector.Event
+.. autoclass:: executorch.sdk.inspector.Event
**Example Usage:**
diff --git a/examples/apple/mps/CMakeLists.txt b/examples/apple/mps/CMakeLists.txt
index 2e5261edcd..89c2b141b0 100644
--- a/examples/apple/mps/CMakeLists.txt
+++ b/examples/apple/mps/CMakeLists.txt
@@ -18,10 +18,6 @@ if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
endif()
-if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
-endif()
-
if(NOT FLATC_EXECUTABLE)
set(FLATC_EXECUTABLE flatc)
endif()
@@ -31,6 +27,12 @@ if(NOT EXECUTORCH_ROOT)
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
endif()
+include(${EXECUTORCH_ROOT}/build/Utils.cmake)
+
+if(NOT PYTHON_EXECUTABLE)
+ resolve_python_executable()
+endif()
+
# Source root directory for pytorch.
if(NOT TORCH_ROOT)
set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch)
diff --git a/examples/arm/CMakeLists.txt b/examples/arm/CMakeLists.txt
index 19fe848727..489c715d1b 100644
--- a/examples/arm/CMakeLists.txt
+++ b/examples/arm/CMakeLists.txt
@@ -16,14 +16,17 @@ project(arm_example)
option(EXECUTORCH_SELECT_OPS_LIST "Register the following list of ops" OFF)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
-
-if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
-endif()
# Source root directory for executorch.
if(NOT EXECUTORCH_ROOT)
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
endif()
+
+include(${EXECUTORCH_ROOT}/build/Utils.cmake)
+
+if(NOT PYTHON_EXECUTABLE)
+ resolve_python_executable()
+endif()
+
# Source root directory for pytorch.
if(NOT TORCH_ROOT)
set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch)
diff --git a/examples/demo-apps/android/LlamaDemo/README.md b/examples/demo-apps/android/LlamaDemo/README.md
index d12c26fb74..fcbe46c9b4 100644
--- a/examples/demo-apps/android/LlamaDemo/README.md
+++ b/examples/demo-apps/android/LlamaDemo/README.md
@@ -14,6 +14,8 @@ adb push model.pte /data/local/tmp/llama
adb push tokenizer.bin /data/local/tmp/llama
```
+The demo app searches in `/data/local/tmp/llama` for .pte and .bin files as LLAMA model and tokenizer.
+
## Build JNI library
1. Open a terminal window and navigate to the root directory of the `executorch`.
2. Set the following environment variables:
@@ -48,23 +50,13 @@ cmake .. -DBUCK2="$BUCK" \
cmake --build . -j50
popd
```
-6.
-Copy the built library to your app:
+6. Copy the built library to your app:
```
JNI_LIBS_PATH="examples/demo-apps/android/LlamaDemo/app/src/main/jniLibs"
mkdir -p "${JNI_LIBS_PATH}/${ANDROID_ABI}"
cp cmake-out/extension/android/libexecutorch_llama_jni.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/"
```
-## Build Java library
-The Java part of the ExecuTorch library can be built with gradlew:
-```
-pushd extension/android
-./gradlew build
-popd
-```
-In the android app, we set up the relative path to the built aar, so no further action is needed.
-
## Build Java app
1. Open Android Studio and select "Open an existing Android Studio project" to open examples/demo-apps/android/LlamaDemo.
2. Run the app (^R).
diff --git a/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj b/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj
index c1ce342968..c5d96bab1b 100644
--- a/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj
+++ b/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj
@@ -61,6 +61,7 @@
/* End PBXCopyFilesBuildPhase section */
/* Begin PBXFileReference section */
+ 0320439D2BB4AC6600050211 /* LLaMA-Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = "LLaMA-Info.plist"; sourceTree = ""; };
0324D6802BAACB6900DEF36F /* App.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = App.swift; sourceTree = ""; };
0324D6812BAACB6900DEF36F /* ContentView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = ""; };
0324D6822BAACB6900DEF36F /* LogManager.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = LogManager.swift; sourceTree = ""; };
@@ -73,6 +74,7 @@
0324D6942BAACB7000DEF36F /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; };
0324D6992BAACB7C00DEF36F /* LLaMARunner.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = LLaMARunner.h; sourceTree = ""; };
0324D69A2BAACB7C00DEF36F /* LLaMARunner.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = LLaMARunner.mm; sourceTree = ""; };
+ 035A5E942BB4B523001E0553 /* LLaMA.entitlements */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.entitlements; path = LLaMA.entitlements; sourceTree = ""; };
036CAF9D2BB1444500D6C2D5 /* LLaMA.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = LLaMA.app; sourceTree = BUILT_PRODUCTS_DIR; };
03729ED52BB1F8DE00152F2E /* LLaMARunner.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = LLaMARunner.framework; sourceTree = BUILT_PRODUCTS_DIR; };
03729F072BB203B300152F2E /* runner.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = runner.cpp; sourceTree = ""; };
@@ -109,6 +111,14 @@
/* End PBXFrameworksBuildPhase section */
/* Begin PBXGroup section */
+ 0320439E2BB4AC6600050211 /* SupportingFiles */ = {
+ isa = PBXGroup;
+ children = (
+ 0320439D2BB4AC6600050211 /* LLaMA-Info.plist */,
+ );
+ path = SupportingFiles;
+ sourceTree = "";
+ };
0324D6892BAACB6900DEF36F /* Application */ = {
isa = PBXGroup;
children = (
@@ -129,6 +139,7 @@
isa = PBXGroup;
children = (
0324D6892BAACB6900DEF36F /* Application */,
+ 0320439E2BB4AC6600050211 /* SupportingFiles */,
);
path = LLaMA;
sourceTree = "";
@@ -174,12 +185,21 @@
children = (
0324D68A2BAACB6900DEF36F /* LLaMA */,
0324D6952BAACB7000DEF36F /* LLaMAAssets */,
+ 035A5E952BB4B523001E0553 /* LLaMAEntitlements */,
0324D69F2BAACB7C00DEF36F /* LLaMARunner */,
036CAF9D2BB1444500D6C2D5 /* LLaMA.app */,
03729ED52BB1F8DE00152F2E /* LLaMARunner.framework */,
);
sourceTree = "";
};
+ 035A5E952BB4B523001E0553 /* LLaMAEntitlements */ = {
+ isa = PBXGroup;
+ children = (
+ 035A5E942BB4B523001E0553 /* LLaMA.entitlements */,
+ );
+ path = LLaMAEntitlements;
+ sourceTree = "";
+ };
03729F062BB2035900152F2E /* runner */ = {
isa = PBXGroup;
children = (
@@ -229,9 +249,9 @@
/* End PBXHeadersBuildPhase section */
/* Begin PBXNativeTarget section */
- 032C016E2AC228E6002955E1 /* App */ = {
+ 032C016E2AC228E6002955E1 /* LLaMA */ = {
isa = PBXNativeTarget;
- buildConfigurationList = 032C017D2AC228E7002955E1 /* Build configuration list for PBXNativeTarget "App" */;
+ buildConfigurationList = 032C017D2AC228E7002955E1 /* Build configuration list for PBXNativeTarget "LLaMA" */;
buildPhases = (
032C016B2AC228E6002955E1 /* Sources */,
032C016C2AC228E6002955E1 /* Frameworks */,
@@ -243,7 +263,7 @@
dependencies = (
03729EDA2BB1F8DE00152F2E /* PBXTargetDependency */,
);
- name = App;
+ name = LLaMA;
packageProductDependencies = (
0395C6D22BB34ED10090705A /* coreml_backend */,
0395C6D62BB34ED10090705A /* mps_backend */,
@@ -310,7 +330,7 @@
projectDirPath = "";
projectRoot = "";
targets = (
- 032C016E2AC228E6002955E1 /* App */,
+ 032C016E2AC228E6002955E1 /* LLaMA */,
03729ED42BB1F8DE00152F2E /* LLaMARunner */,
);
};
@@ -512,13 +532,16 @@
isa = XCBuildConfiguration;
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
+ CODE_SIGN_ENTITLEMENTS = LLaMAEntitlements/LLaMA.entitlements;
CODE_SIGN_IDENTITY = "Apple Development";
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_TEAM = "";
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
+ INFOPLIST_FILE = "LLaMA/SupportingFiles/LLaMA-Info.plist";
INFOPLIST_KEY_CFBundleDisplayName = iLLaMA;
+ INFOPLIST_KEY_LSSupportsOpeningDocumentsInPlace = YES;
INFOPLIST_KEY_NSCameraUsageDescription = "";
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
@@ -542,13 +565,16 @@
isa = XCBuildConfiguration;
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
+ CODE_SIGN_ENTITLEMENTS = LLaMAEntitlements/LLaMA.entitlements;
CODE_SIGN_IDENTITY = "Apple Development";
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_TEAM = "";
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
+ INFOPLIST_FILE = "LLaMA/SupportingFiles/LLaMA-Info.plist";
INFOPLIST_KEY_CFBundleDisplayName = iLLaMA;
+ INFOPLIST_KEY_LSSupportsOpeningDocumentsInPlace = YES;
INFOPLIST_KEY_NSCameraUsageDescription = "";
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
@@ -652,7 +678,7 @@
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
- 032C017D2AC228E7002955E1 /* Build configuration list for PBXNativeTarget "App" */ = {
+ 032C017D2AC228E7002955E1 /* Build configuration list for PBXNativeTarget "LLaMA" */ = {
isa = XCConfigurationList;
buildConfigurations = (
032C017E2AC228E7002955E1 /* Debug */,
diff --git a/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/xcshareddata/xcschemes/App.xcscheme b/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/xcshareddata/xcschemes/LLaMA.xcscheme
similarity index 95%
rename from examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/xcshareddata/xcschemes/App.xcscheme
rename to examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/xcshareddata/xcschemes/LLaMA.xcscheme
index 68bdac5800..e02694c9f0 100644
--- a/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/xcshareddata/xcschemes/App.xcscheme
+++ b/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/xcshareddata/xcschemes/LLaMA.xcscheme
@@ -16,7 +16,7 @@
BuildableIdentifier = "primary"
BlueprintIdentifier = "032C016E2AC228E6002955E1"
BuildableName = "LLaMA.app"
- BlueprintName = "App"
+ BlueprintName = "LLaMA"
ReferencedContainer = "container:LLaMA.xcodeproj">
@@ -45,7 +45,7 @@
BuildableIdentifier = "primary"
BlueprintIdentifier = "032C016E2AC228E6002955E1"
BuildableName = "LLaMA.app"
- BlueprintName = "App"
+ BlueprintName = "LLaMA"
ReferencedContainer = "container:LLaMA.xcodeproj">
@@ -61,7 +61,7 @@
BuildableIdentifier = "primary"
BlueprintIdentifier = "032C016E2AC228E6002955E1"
BuildableName = "LLaMA.app"
- BlueprintName = "App"
+ BlueprintName = "LLaMA"
ReferencedContainer = "container:LLaMA.xcodeproj">
diff --git a/examples/demo-apps/apple_ios/LLaMA/LLaMA/SupportingFiles/LLaMA-Info.plist b/examples/demo-apps/apple_ios/LLaMA/LLaMA/SupportingFiles/LLaMA-Info.plist
new file mode 100644
index 0000000000..ff579a6caf
--- /dev/null
+++ b/examples/demo-apps/apple_ios/LLaMA/LLaMA/SupportingFiles/LLaMA-Info.plist
@@ -0,0 +1,8 @@
+
+
+
+
+ UIFileSharingEnabled
+
+
+
diff --git a/examples/demo-apps/apple_ios/LLaMA/LLaMAEntitlements/LLaMA.entitlements b/examples/demo-apps/apple_ios/LLaMA/LLaMAEntitlements/LLaMA.entitlements
new file mode 100644
index 0000000000..99f471672d
--- /dev/null
+++ b/examples/demo-apps/apple_ios/LLaMA/LLaMAEntitlements/LLaMA.entitlements
@@ -0,0 +1,8 @@
+
+
+
+
+ com.apple.developer.kernel.increased-memory-limit
+
+
+
diff --git a/examples/models/llama2/CMakeLists.txt b/examples/models/llama2/CMakeLists.txt
index 3ebe142d4f..7da4dabbc9 100644
--- a/examples/models/llama2/CMakeLists.txt
+++ b/examples/models/llama2/CMakeLists.txt
@@ -20,13 +20,15 @@ project(llama_runner)
option(EXECUTORCH_BUILD_OPTIMIZED "Build the optimized kernels" OFF)
-if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
-endif()
-
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch)
+include(${EXECUTORCH_ROOT}/build/Utils.cmake)
+
+if(NOT PYTHON_EXECUTABLE)
+ resolve_python_executable()
+endif()
+
if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
# Can't set to 11 due to executor_runner.cpp make_unique
@@ -54,21 +56,15 @@ find_package(executorch CONFIG REQUIRED)
# llama_runner library
add_subdirectory(runner)
-set(link_options)
set(link_libraries)
if(EXECUTORCH_BUILD_OPTIMIZED)
- list(APPEND link_libraries optimized_native_cpu_ops_lib optimized_kernels portable_kernels)
- list(APPEND link_options
- "SHELL:LINKER:--whole-archive \
- $ \
- LINKER:--no-whole-archive")
+ list(APPEND link_libraries optimized_native_cpu_ops_lib optimized_kernels
+ portable_kernels)
+ target_link_options_shared_lib(optimized_native_cpu_ops_lib)
else()
list(APPEND link_libraries portable_ops_lib portable_kernels)
- list(APPEND link_options
- "SHELL:LINKER:--whole-archive \
- $ \
- LINKER:--no-whole-archive")
+ target_link_options_shared_lib(portable_ops_lib)
endif()
target_link_libraries(llama_main PUBLIC gflags llama_runner)
@@ -77,24 +73,21 @@ target_link_libraries(llama_main PUBLIC gflags llama_runner)
if(TARGET xnnpack_backend)
set(xnnpack_backend_libs xnnpack_backend XNNPACK pthreadpool cpuinfo)
list(APPEND link_libraries ${xnnpack_backend_libs})
- list(APPEND link_options
- "SHELL:LINKER:--whole-archive \
- $ \
- LINKER:--no-whole-archive")
+ target_link_options_shared_lib(xnnpack_backend)
endif()
# Vulkan backend
if(TARGET vulkan_backend)
list(APPEND link_libraries vulkan_backend)
- list(APPEND link_options
- "SHELL:LINKER:--whole-archive \
- $ \
- LINKER:--no-whole-archive")
+ target_link_options_shared_lib(vulkan_backend)
endif()
target_compile_options(llama_main PUBLIC ${_common_compile_options})
target_link_libraries(llama_main PUBLIC ${link_libraries})
-target_link_options(llama_main PUBLIC ${link_options})
+
+if(APPLE)
+ target_link_options_shared_lib(executorch)
+endif()
# Print all summary
executorch_print_configuration_summary()
diff --git a/examples/models/llama2/builder.py b/examples/models/llama2/builder.py
index f625292c3d..3473391b64 100644
--- a/examples/models/llama2/builder.py
+++ b/examples/models/llama2/builder.py
@@ -283,7 +283,9 @@ def export_to_edge(
)
return self
- def to_backend(self, partitioner: Optional[Partitioner]) -> "LlamaEdgeManager":
+ def to_backend(
+ self, partitioners: Optional[List[Partitioner]]
+ ) -> "LlamaEdgeManager":
"""
Partition the model and lower to different backends. The signature is
aligned with the signature of `to_backend` method of EdgeManager.
@@ -291,18 +293,26 @@ def to_backend(self, partitioner: Optional[Partitioner]) -> "LlamaEdgeManager":
partitioner (Optional[Partitioner]): One or more
partitioner to be sent to EdgeManager.to_backend().
"""
- assert self.edge_manager is not None, "Need to run export_to_edge() first"
- if partitioner is None:
+ if partitioners is None:
logging.info("No partitioner provided, passing...")
else:
- self.edge_manager = self.edge_manager.to_backend(partitioner)
- if self.verbose:
- logging.info(
- print_delegated_graph(
- self.edge_manager.exported_program().graph_module
- )
- )
- logging.info(f"Applied partitioners: {partitioner}")
+ for partitioner in partitioners:
+ if partitioner is not None:
+ assert (
+ self.edge_manager is not None
+ ), "Need to run export_to_edge() first"
+ self.edge_manager = self.edge_manager.to_backend(partitioner)
+ if self.verbose:
+ logging.info(
+ print_delegated_graph(
+ self.edge_manager.exported_program().graph_module
+ )
+ )
+ logging.info(f"Applied partitioners: {partitioner}")
+ else:
+ logging.info("No partitioner provided, passing...")
+ continue
+
return self
def to_executorch(self) -> "LlamaEdgeManager":
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 4d3593a6e0..30ef1d786d 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -9,12 +9,13 @@
import argparse
import copy
import logging
+import os
import shlex
from dataclasses import dataclass
from functools import partial
from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Union
import pkg_resources
import torch
@@ -237,8 +238,12 @@ def quantize(
else:
torch_dtype = torch.float16
- if checkpoint_path is None:
- checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+ assert checkpoint_path, "Need to specify a checkpoint"
+ assert os.path.isfile(
+ canonical_path(checkpoint_path)
+ ), f"{checkpoint_path} does not exist"
+ # if checkpoint_path is None:
+ # checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
if calibration_tasks is None:
calibration_tasks = ["wikitext"]
@@ -457,7 +462,9 @@ def build_args_parser() -> argparse.ArgumentParser:
return parser
-def canonical_path(path: str, *, dir: bool = False) -> str:
+def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
+
+ path = str(path)
if verbose_export():
print(f"creating canonical path for {path}")
@@ -602,9 +609,9 @@ def _export_llama(modelname, args) -> str: # noqa: C901
).export_to_edge(quantizers)
# to_backend
- partitioner = None
+ partitioners = []
if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
- partitioner = XnnpackDynamicallyQuantizedPartitioner()
+ partitioners.append(XnnpackDynamicallyQuantizedPartitioner())
modelname = f"xnnpack_dq_{modelname}"
if args.xnnpack:
@@ -612,8 +619,8 @@ def _export_llama(modelname, args) -> str: # noqa: C901
# 1. We need dynamically quantized partitioner for both pt2e_quantize options
# as well as "qmode 8da4w" which is also dynamic quantizes linear layers.
# 2. XNNPACK partitioner seems to result in seg fault for non dqlinear ops.
- partitioner = XnnpackDynamicallyQuantizedPartitioner()
- # partitioner = XnnpackPartitioner()
+ partitioners.append(XnnpackDynamicallyQuantizedPartitioner())
+ # partitioners.append(XnnpackPartitioner())
modelname = f"xnnpack_{modelname}"
if args.vulkan:
@@ -624,7 +631,7 @@ def _export_llama(modelname, args) -> str: # noqa: C901
args.quantization_mode is None
), "Vulkan backend does not support quantization at the moment"
- partitioner = VulkanPartitioner()
+ partitioners.append(VulkanPartitioner())
modelname = f"vulkan_{modelname}"
if args.mps:
@@ -643,7 +650,7 @@ def _export_llama(modelname, args) -> str: # noqa: C901
compile_specs = [CompileSpec("use_fp16", bytes([True]))]
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `apple`.
- partitioner = MPSPartitioner(compile_specs)
+ partitioners.append(MPSPartitioner(compile_specs))
modelname = f"mps_{modelname}"
if args.coreml:
@@ -673,9 +680,11 @@ def _export_llama(modelname, args) -> str: # noqa: C901
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `apple`
model_type=CoreMLBackend.MODEL_TYPE.MODEL,
)
- # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `apple`
- partitioner = CoreMLPartitioner(
- skip_ops_for_coreml_delegation=None, compile_specs=compile_specs
+ partitioners.append(
+ # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `apple`
+ CoreMLPartitioner(
+ skip_ops_for_coreml_delegation=None, compile_specs=compile_specs
+ )
)
modelname = f"coreml_{modelname}"
@@ -707,18 +716,20 @@ def _export_llama(modelname, args) -> str: # noqa: C901
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
backend_options = generate_htp_compiler_spec(use_fp16=False)
- # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
- partitioner = QnnPartitioner(
+ partitioners.append(
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
- generate_qnn_executorch_compiler_spec(
- # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
- soc_model=QcomChipset.SM8650, # default to SM8650
- backend_options=backend_options,
- debug=False,
- saver=False,
- ),
- skip_node_id_set={},
- skip_node_op_set={},
+ QnnPartitioner(
+ # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
+ generate_qnn_executorch_compiler_spec(
+ # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+ soc_model=QcomChipset.SM8650, # default to SM8650
+ backend_options=backend_options,
+ debug=False,
+ saver=False,
+ ),
+ skip_node_id_set={},
+ skip_node_op_set={},
+ )
)
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
_transform(builder_exported_to_edge.export_program())
@@ -730,7 +741,7 @@ def _export_llama(modelname, args) -> str: # noqa: C901
logging.info("Generating etrecord")
# Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive.
edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager)
- builder = builder_exported_to_edge.to_backend(partitioner).to_executorch()
+ builder = builder_exported_to_edge.to_backend(partitioners).to_executorch()
# Generate ETRecord
if edge_manager_copy:
@@ -741,7 +752,7 @@ def _export_llama(modelname, args) -> str: # noqa: C901
)
logging.info("Generated etrecord.bin")
else:
- builder = builder_exported_to_edge.to_backend(partitioner).to_executorch()
+ builder = builder_exported_to_edge.to_backend(partitioners).to_executorch()
if args.profile_memory:
generate_memory_trace(builder.export_program, "memory_profile.json")
diff --git a/examples/models/llama2/runner/CMakeLists.txt b/examples/models/llama2/runner/CMakeLists.txt
index a21995281d..75802f91f9 100644
--- a/examples/models/llama2/runner/CMakeLists.txt
+++ b/examples/models/llama2/runner/CMakeLists.txt
@@ -39,7 +39,7 @@ list(TRANSFORM _llama_runner__srcs PREPEND "${EXECUTORCH_ROOT}/")
target_include_directories(extension_module
INTERFACE ${_common_include_directories})
-if(CMAKE_TOOLCHAIN_IOS OR CMAKE_TOOLCHAIN_ANDROID)
+if(CMAKE_TOOLCHAIN_IOS OR CMAKE_TOOLCHAIN_ANDROID OR APPLE)
# Building a share library on iOS requires code signing
# On Android we see duplicated registration when using shared lib
add_library(llama_runner STATIC ${_llama_runner__srcs})
diff --git a/examples/portable/custom_ops/CMakeLists.txt b/examples/portable/custom_ops/CMakeLists.txt
index e23e26993e..289fc4dda9 100644
--- a/examples/portable/custom_ops/CMakeLists.txt
+++ b/examples/portable/custom_ops/CMakeLists.txt
@@ -22,9 +22,6 @@ if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
endif()
-if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
-endif()
# Source root directory for executorch.
if(NOT EXECUTORCH_ROOT)
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
@@ -37,6 +34,10 @@ endif()
include(${EXECUTORCH_ROOT}/build/Utils.cmake)
include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
+if(NOT PYTHON_EXECUTABLE)
+ resolve_python_executable()
+endif()
+
set(_common_compile_options -Wno-deprecated-declarations -fPIC)
# Let files say "include ".
diff --git a/examples/qualcomm/CMakeLists.txt b/examples/qualcomm/CMakeLists.txt
index 54772f5c78..cff5db2a63 100644
--- a/examples/qualcomm/CMakeLists.txt
+++ b/examples/qualcomm/CMakeLists.txt
@@ -11,9 +11,6 @@ endif()
cmake_minimum_required(VERSION 3.19)
project(qualcomm_runner_example)
-if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
-endif()
# Source root directory for executorch.
if(NOT EXECUTORCH_ROOT)
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
@@ -23,6 +20,13 @@ if(NOT TORCH_ROOT)
set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch)
endif()
+include(${EXECUTORCH_ROOT}/build/Utils.cmake)
+include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
+
+if(NOT PYTHON_EXECUTABLE)
+ resolve_python_executable()
+endif()
+
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Debug)
endif()
@@ -41,7 +45,6 @@ set(_common_include_directories ${EXECUTORCH_ROOT}/..)
#
# The `__srcs` lists are defined by including ${EXECUTORCH_SRCS_FILE}.
#
-include(${EXECUTORCH_ROOT}/build/Utils.cmake)
set(EXECUTORCH_SRCS_FILE
"${CMAKE_CURRENT_BINARY_DIR}/../../executorch_srcs.cmake"
)
@@ -55,7 +58,6 @@ get_filename_component(EXECUTORCH_SOURCE_DIR
set(_qnn_executor_runner__srcs ${_executor_runner__srcs})
# portable_ops_lib
-include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
gen_selected_ops("" "" "ON")
generate_bindings_for_kernels(
FUNCTIONS_YAML ${EXECUTORCH_ROOT}/kernels/portable/functions.yaml
diff --git a/examples/sdk/CMakeLists.txt b/examples/sdk/CMakeLists.txt
index 06e91ddbb5..d7ca7679e3 100644
--- a/examples/sdk/CMakeLists.txt
+++ b/examples/sdk/CMakeLists.txt
@@ -14,9 +14,6 @@ if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
endif()
-if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
-endif()
# Source root directory for executorch.
if(NOT EXECUTORCH_ROOT)
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
@@ -29,6 +26,10 @@ endif()
include(${EXECUTORCH_ROOT}/build/Utils.cmake)
include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
+if(NOT PYTHON_EXECUTABLE)
+ resolve_python_executable()
+endif()
+
set(_common_compile_options -Wno-deprecated-declarations -fPIC)
# Let files say "include ".
diff --git a/examples/selective_build/CMakeLists.txt b/examples/selective_build/CMakeLists.txt
index 38859f0c67..2979118718 100644
--- a/examples/selective_build/CMakeLists.txt
+++ b/examples/selective_build/CMakeLists.txt
@@ -18,15 +18,16 @@
cmake_minimum_required(VERSION 3.19)
project(selective_build_example)
-if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
-endif()
-
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch)
+
include(${EXECUTORCH_ROOT}/build/Utils.cmake)
include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
+if(NOT PYTHON_EXECUTABLE)
+ resolve_python_executable()
+endif()
+
if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
# Can't set to 11 due to executor_runner.cpp make_unique
diff --git a/examples/xtensa/CMakeLists.txt b/examples/xtensa/CMakeLists.txt
index c5187db75d..8c9a251a16 100644
--- a/examples/xtensa/CMakeLists.txt
+++ b/examples/xtensa/CMakeLists.txt
@@ -11,10 +11,6 @@ if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
endif()
-if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
-endif()
-
# Set the project name.
project(xtensa_executorch_example)
@@ -23,6 +19,12 @@ if(NOT EXECUTORCH_ROOT)
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
endif()
+include(${EXECUTORCH_ROOT}/build/Utils.cmake)
+
+if(NOT PYTHON_EXECUTABLE)
+ resolve_python_executable()
+endif()
+
# Let files say "include ".
set(_common_include_directories ${EXECUTORCH_ROOT}/..)
diff --git a/examples/xtensa/ops/CMakeLists.txt b/examples/xtensa/ops/CMakeLists.txt
index a848e3bc9a..215de49f20 100644
--- a/examples/xtensa/ops/CMakeLists.txt
+++ b/examples/xtensa/ops/CMakeLists.txt
@@ -11,10 +11,6 @@ if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
endif()
-if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
-endif()
-
# Source root directory for pytorch.
if(NOT TORCH_ROOT)
set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch)
@@ -23,6 +19,10 @@ endif()
include(${EXECUTORCH_ROOT}/build/Utils.cmake)
include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
+if(NOT PYTHON_EXECUTABLE)
+ resolve_python_executable()
+endif()
+
# ATen compliant ops that are needed to run this model.
set(_aten_ops__srcs
"${CMAKE_CURRENT_SOURCE_DIR}/op_add.cpp"
diff --git a/exir/backend/TARGETS b/exir/backend/TARGETS
index dd872dd2b2..70e235d3ad 100644
--- a/exir/backend/TARGETS
+++ b/exir/backend/TARGETS
@@ -106,6 +106,7 @@ runtime.python_library(
"@EXECUTORCH_CLIENTS",
],
deps = [
+ "fbsource//third-party/pypi/pandas:pandas",
"//caffe2:torch",
"//executorch/exir:lowered_backend_module",
],
diff --git a/exir/backend/test/TARGETS b/exir/backend/test/TARGETS
index bcccdfed11..ecf03ceb61 100644
--- a/exir/backend/test/TARGETS
+++ b/exir/backend/test/TARGETS
@@ -269,6 +269,7 @@ python_unittest(
"test_utils.py",
],
deps = [
+ "fbsource//third-party/pypi/pandas:pandas",
":op_partitioner_demo",
"//caffe2:torch",
"//executorch/exir:lib",
diff --git a/exir/backend/test/test_utils.py b/exir/backend/test/test_utils.py
index fa57e8493f..098dd8e308 100644
--- a/exir/backend/test/test_utils.py
+++ b/exir/backend/test/test_utils.py
@@ -6,6 +6,8 @@
import unittest
+import pandas as pd
+
import torch
from executorch import exir
from executorch.exir import CaptureConfig, to_edge
@@ -13,7 +15,9 @@
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
from executorch.exir.backend.utils import (
+ DelegationBreakdown,
get_delegates,
+ get_delegation_info,
get_non_lowered_nodes,
is_identical_graph,
print_delegated_graph,
@@ -22,6 +26,7 @@
)
from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
+from pandas.testing import assert_frame_equal
from torch.ao.quantization import get_default_qconfig # @manual
from torch.ao.quantization.backend_config.executorch import (
get_executorch_backend_config,
@@ -439,3 +444,65 @@ def forward(self, a, x, b):
graph_str,
"Expect to see the aten.mm in the delegated graph",
)
+
+ def test_get_delegation_info(self):
+ class Model(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, a, x, b):
+ y = torch.mm(a, x)
+ z = y + b
+ a = z - a
+ y = torch.mm(a, x)
+ z = y + b
+ return z
+
+ m = Model()
+ inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
+ edge = to_edge(torch.export.export(m, inputs)).to_backend(
+ AddMulPartitionerDemo()
+ )
+ delegation_info = get_delegation_info(edge.exported_program().graph_module)
+
+ self.assertEqual(delegation_info.num_delegated_subgraphs, 2)
+ self.assertEqual(delegation_info.num_delegated_nodes, 4)
+ self.assertEqual(delegation_info.num_non_delegated_nodes, 3)
+ expected_delegation_by_op_dict = {
+ "aten_add_tensor": DelegationBreakdown(
+ op_type="aten_add_tensor", delegated=2, non_delegated=0
+ ),
+ "aten_mm_default": DelegationBreakdown(
+ op_type="aten_mm_default", delegated=2, non_delegated=0
+ ),
+ "aten_sub_tensor": DelegationBreakdown(
+ op_type="aten_sub_tensor", delegated=0, non_delegated=1
+ ),
+ "getitem": DelegationBreakdown(
+ op_type="getitem", delegated=0, non_delegated=2
+ ),
+ }
+ self.assertEqual(
+ delegation_info.delegation_by_operator, expected_delegation_by_op_dict
+ )
+
+ self.assertIn(
+ "Total delegated subgraphs",
+ delegation_info.get_summary(),
+ )
+
+ df = delegation_info.get_operator_delegation_dataframe()
+ expected_df = pd.DataFrame(
+ {
+ "op_type": [
+ "aten_add_tensor",
+ "aten_mm_default",
+ "aten_sub_tensor",
+ "getitem",
+ "Total",
+ ],
+ "occurrences_in_delegated_graphs": [2, 2, 0, 0, 4],
+ "occurrences_in_non_delegated_graphs": [0, 0, 1, 2, 3],
+ }
+ )
+ assert_frame_equal(expected_df, df)
diff --git a/exir/backend/utils.py b/exir/backend/utils.py
index 01906146e3..34ffab3f4b 100644
--- a/exir/backend/utils.py
+++ b/exir/backend/utils.py
@@ -6,10 +6,13 @@
import logging
import operator
+import re
from collections import defaultdict
+from dataclasses import asdict, dataclass
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+import pandas as pd
import torch
from executorch.exir.backend.backend_details import ExportedProgram
from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
@@ -27,6 +30,12 @@
T_QuantPerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
T_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
+# Column names of the DataFrame returned by DelegationInfo.get_operator_delegation_dataframe()
+# which describes the summarized delegation information grouped by each operator type
+_OCCURRENCES_IN_DELEGATED_GRAPHS = "occurrences_in_delegated_graphs"
+_OCCURRENCES_IN_NON_DELEGATED_GRAPHS = "occurrences_in_non_delegated_graphs"
+
+
log: logging.Logger = logging.getLogger(__name__)
@@ -280,6 +289,163 @@ def get_delegates(graph: torch.fx.Graph) -> List[torch.fx.Node]:
]
+@dataclass
+class DelegationBreakdown:
+ """
+ DelegationBreakdown contains the number of delegated and non-delegated nodes
+ of the operator type op_type.
+
+ Args:
+ delegated: The number of delegated nodes.
+ non_delegated: The number of non-delegated nodes.
+ """
+
+ op_type: str = ""
+ delegated: int = 0
+ non_delegated: int = 0
+
+
+@dataclass
+class DelegationInfo:
+ """
+ DelegationInfo contains information of a delegated graph module.
+
+ Args:
+ num_delegated_subgraphs: The number of delegated subgraphs.
+ num_delegated_nodes: The number of delegated nodes.
+ num_non_delegated_nodes: The number of non-delegated nodes.
+ delegation_by_operator: A dictionary of operator type to DelegationBreakdown.
+ """
+
+ num_delegated_subgraphs: int
+ num_delegated_nodes: int
+ num_non_delegated_nodes: int
+ delegation_by_operator: Dict[str, DelegationBreakdown]
+
+ def get_summary(self) -> str:
+ """
+ Get a summary of the delegation information in string format.
+
+ Args:
+ None
+
+ Returns:
+ A string containing information of some class attributes for easy print-out.
+ """
+
+ # Assemble and return the summary string
+ summary_str = f"Total delegated subgraphs: {self.num_delegated_subgraphs}\n"
+ summary_str += f"Number of delegated nodes: {self.num_delegated_nodes}\n"
+ summary_str += (
+ f"Number of non-delegated nodes: {self.num_non_delegated_nodes}\n"
+ )
+ return summary_str
+
+ def get_operator_delegation_dataframe(self) -> pd.DataFrame:
+ """
+ Get the delegation information grouped by operator type in a pandas DataFrame.
+
+ Args:
+ None
+
+ Returns:
+ Returns a pandas DataFrame containing the following columns:
+ - op_type: The operator type, with the last row being "Total".
+ - occurrences_in_delegated_graphs: The number of occurrences of the op_type in delegated subgraphs.
+ - occurrences_in_non_delegated_graphs: The number of occurrences of the op_type not in delegated subgraphs.
+ With the last row being the total number of delegated and non-delegated occurrences of each op_type.
+ """
+
+ # Convert the dict to a dataframe
+ list_of_dicts = [
+ asdict(breakdown) for breakdown in self.delegation_by_operator.values()
+ ]
+ df = pd.DataFrame(list_of_dicts)
+ # Rename columns for better understandability
+ df = df.rename(
+ columns={
+ "delegated": _OCCURRENCES_IN_DELEGATED_GRAPHS,
+ "non_delegated": _OCCURRENCES_IN_NON_DELEGATED_GRAPHS,
+ }
+ )
+ df = df.sort_values(by="op_type", ignore_index=True)
+
+ # Add a Total row at the bottom
+ total_delegated_nodes = df[_OCCURRENCES_IN_DELEGATED_GRAPHS].sum()
+ total_non_delegated_nodes = df[_OCCURRENCES_IN_NON_DELEGATED_GRAPHS].sum()
+ df.loc[len(df)] = ["Total", total_delegated_nodes, total_non_delegated_nodes]
+
+ return df
+
+
+def get_delegation_info(
+ graph_module: torch.fx.GraphModule,
+) -> DelegationInfo:
+ """
+ Util function to get the delegation information of the given graph module.
+
+ Args:
+ graph_module: The lowered graph module to get the delegation information from.
+
+ Returns:
+ Return a DelegationInfo object containing the delegation information.
+ """
+
+ def _get_op_type(node_name: str) -> str:
+ # node_name is in format or _x in which x is an integer suffix.
+ return re.sub(r"_[\d]+$", "", node_name)
+
+ op_occurrences_dict = defaultdict(lambda: DelegationBreakdown())
+
+ def _insert_op_occurrences_dict(node_name: str, delegated: bool) -> None:
+ op_type = _get_op_type(node_name)
+ op_occurrences_dict[op_type].op_type = op_type
+ if delegated:
+ op_occurrences_dict[op_type].delegated += 1
+ else:
+ op_occurrences_dict[op_type].non_delegated += 1
+
+ delegated_subgraph_counter = 0
+
+ lowered_module_dict = {
+ node.name: getattr(graph_module, node.name)
+ for node in graph_module.graph.nodes
+ if node.op == "get_attr" and node.name.startswith("lowered_module_")
+ }
+
+ for node in graph_module.graph.nodes:
+ if (
+ node.op == "call_function"
+ and _get_op_type(node.name) != "executorch_call_delegate"
+ ):
+ # Non-delegated node
+ _insert_op_occurrences_dict(node_name=node.name, delegated=False)
+ # Check if the node is a lowered module
+ if node.op == "get_attr" and node.name.startswith("lowered_module_"):
+ lowered_module = lowered_module_dict[node.name]
+ delegated_subgraph_counter += 1
+ for node_in_lowered_module in lowered_module.original_module.graph.nodes:
+ if node_in_lowered_module.op == "call_function":
+ # Delegated node
+ _insert_op_occurrences_dict(
+ node_name=node_in_lowered_module.name, delegated=True
+ )
+
+ # Calculate the total number of delegated and non-delegated nodes
+ num_delegated_nodes = 0
+ num_non_delegated_nodes = 0
+ for value in op_occurrences_dict.values():
+ num_delegated_nodes += value.delegated
+ num_non_delegated_nodes += value.non_delegated
+
+ return DelegationInfo(
+ num_delegated_nodes=num_delegated_nodes,
+ num_non_delegated_nodes=num_non_delegated_nodes,
+ num_delegated_subgraphs=delegated_subgraph_counter,
+ delegation_by_operator=op_occurrences_dict,
+ )
+
+
def print_delegated_graph(graph_module: torch.fx.GraphModule) -> str:
"""
Print the graph of including lowered_module (both backend id and original graph) together with the graph module. Example output:
diff --git a/exir/delegate.py b/exir/delegate.py
index e516777e33..959bd4bb17 100644
--- a/exir/delegate.py
+++ b/exir/delegate.py
@@ -24,13 +24,9 @@
executorch_call_delegate = HigherOrderOperator("executorch_call_delegate")
-# pyre-ignore
executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher)
-# pyre-ignore
executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot)
executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView)
-executorch_call_delegate.fallthrough(torch._C.DispatchKey.BackendSelect)
-# pyre-ignore
executorch_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU)
LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule"
diff --git a/extension/android/.gitignore b/extension/android/.gitignore
new file mode 100644
index 0000000000..89d520b140
--- /dev/null
+++ b/extension/android/.gitignore
@@ -0,0 +1,5 @@
+local.properties
+.gradle
+.idea/*
+.externalNativeBuild
+build
diff --git a/extension/aot_util/CMakeLists.txt b/extension/aot_util/CMakeLists.txt
deleted file mode 100644
index dd3e9de21f..0000000000
--- a/extension/aot_util/CMakeLists.txt
+++ /dev/null
@@ -1,68 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-# Please this file formatted by running:
-# ~~~
-# cmake-format --first-comment-is-literal=True CMakeLists.txt
-# ~~~
-
-cmake_minimum_required(VERSION 3.19)
-project(aot_util)
-include(../../build/Utils.cmake)
-
-if(NOT CMAKE_CXX_STANDARD)
- set(CMAKE_CXX_STANDARD 17)
-endif()
-
-if(NOT CMAKE_BUILD_TYPE)
- set(CMAKE_BUILD_TYPE Debug)
-endif()
-
-if(NOT EXECUTORCH_ROOT)
- set(EXECUTORCH_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../../")
-endif()
-
-if(NOT BUCK2)
- set(BUCK2 buck2)
-endif()
-
-if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
-endif()
-
-if(NOT EXECUTORCH_SRCS_FILE)
- # A file wasn't provided. Run a script to extract the source lists from the
- # buck2 build system and write them to a file we can include.
- #
- # NOTE: This will only happen once during cmake setup, so it will not re-run
- # if the buck2 targets change.
- message(STATUS "executorch: Generating source lists")
- set(EXECUTORCH_SRCS_FILE "${CMAKE_CURRENT_BINARY_DIR}/executorch_srcs.cmake")
- extract_sources(${EXECUTORCH_SRCS_FILE})
-endif()
-
-# This file defines the `___srcs` variables used below.
-message(STATUS "executorch: Using sources file ${EXECUTORCH_SRCS_FILE}")
-include(${EXECUTORCH_SRCS_FILE})
-
-
-# Ahead-of-time (AOT) utility library. Contains native code used by the
-# AOT lowering and delegation logic. Note that this library should build
-# independently of the runtime code, and as such, should not have
-# dependencies on runtime targets.
-find_package(Torch CONFIG REQUIRED)
-find_library(TORCH_PYTHON_LIBRARY torch_python
- PATHS "${TORCH_INSTALL_PREFIX}/lib")
-
-# Override compiler flags set in upper scope when included from the top-level
-# CMakeLists. ExecuTorch builds with -fno-exceptions and -fno-rtti, but we
-# need these for ATen.
-unset(CMAKE_CXX_FLAGS_RELEASE)
-
-list(TRANSFORM _extension_aot_util__srcs PREPEND "${EXECUTORCH_ROOT}/")
-add_library(aot_util ${_extension_aot_util__srcs})
-target_include_directories(aot_util PUBLIC ${TORCH_INCLUDE_DIRS})
-target_link_libraries(aot_util torch)
diff --git a/extension/aot_util/TARGETS b/extension/aot_util/TARGETS
deleted file mode 100644
index d2f94c3505..0000000000
--- a/extension/aot_util/TARGETS
+++ /dev/null
@@ -1,12 +0,0 @@
-load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
-
-oncall("executorch")
-
-runtime.cxx_library(
- name = "aot_util",
- srcs = ["convert_to_qc4w.cpp"],
- visibility = [
- "//executorch/...",
- ],
- external_deps = ["libtorch"],
-)
diff --git a/extension/aot_util/convert_to_qc4w.cpp b/extension/aot_util/convert_to_qc4w.cpp
deleted file mode 100644
index 8fb3293397..0000000000
--- a/extension/aot_util/convert_to_qc4w.cpp
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include
-#include
-
-at::Tensor convert_to_qc4w(at::Tensor x) {
- std::vector sizes = x.sizes().vec();
- TORCH_CHECK(sizes.size() == 2, "Expecting 2D tensor");
- TORCH_CHECK(sizes[1] % 2 == 0);
- TORCH_CHECK(
- x.options().dtype() == at::kByte, "Input tensor must be of type uint8.");
- sizes[1] = sizes[1] / 2;
- at::Tensor output = at::empty(sizes, x.options().dtype());
- uint8_t* x_ptr = x.data_ptr();
- uint8_t* output_ptr = output.data_ptr();
- for (int i = 0; i < output.numel(); ++i) {
- int32_t input_i = i * 2;
- int32_t input_i_plus_1 = i * 2 + 1;
- output_ptr[i] = (x_ptr[input_i_plus_1] << 4) | (x_ptr[input_i]);
- }
- return output;
-}
-
-TORCH_LIBRARY_FRAGMENT(xnnpack, m) {
- m.def("convert_to_qc4w", &convert_to_qc4w);
-}
diff --git a/extension/apple/ExecuTorch/Exported/module.modulemap b/extension/apple/ExecuTorch/Exported/module.modulemap
index 7c094ce35a..6ac771cd17 100644
--- a/extension/apple/ExecuTorch/Exported/module.modulemap
+++ b/extension/apple/ExecuTorch/Exported/module.modulemap
@@ -1,5 +1,5 @@
module ExecuTorch {
- umbrella header "ExecuTorch.h"
+ umbrella header "ExecuTorch/ExecuTorch.h"
export *
}
\ No newline at end of file
diff --git a/extension/module/CMakeLists.txt b/extension/module/CMakeLists.txt
index e0d7ccc250..e36cfa3760 100644
--- a/extension/module/CMakeLists.txt
+++ b/extension/module/CMakeLists.txt
@@ -17,7 +17,7 @@ if(NOT EXECUTORCH_ROOT)
endif()
list(TRANSFORM _extension_module__srcs PREPEND "${EXECUTORCH_ROOT}/")
-if(CMAKE_TOOLCHAIN_IOS OR CMAKE_TOOLCHAIN_ANDROID)
+if(CMAKE_TOOLCHAIN_IOS OR CMAKE_TOOLCHAIN_ANDROID OR APPLE)
# Building a share library on iOS requires code signing
# On Android we see duplicated registration when using shared lib
add_library(extension_module STATIC ${_extension_module__srcs})
diff --git a/install_requirements.sh b/install_requirements.sh
index 05e36d2484..27f6d85c13 100755
--- a/install_requirements.sh
+++ b/install_requirements.sh
@@ -7,14 +7,14 @@
# Install required python dependencies for developing
# Dependencies are defined in .pyproject.toml
-if [[ -z $BUCK ]];
-then
- BUCK=buck2
-fi
-
if [[ -z $PYTHON_EXECUTABLE ]];
then
- PYTHON_EXECUTABLE=python3
+ if [[ -z $CONDA_DEFAULT_ENV ]] || [[ $CONDA_DEFAULT_ENV == "base" ]];
+ then
+ PYTHON_EXECUTABLE=python3
+ else
+ PYTHON_EXECUTABLE=python
+ fi
fi
# Parse options.
diff --git a/kernels/optimized/CMakeLists.txt b/kernels/optimized/CMakeLists.txt
index d363515260..a33c32dd95 100644
--- a/kernels/optimized/CMakeLists.txt
+++ b/kernels/optimized/CMakeLists.txt
@@ -16,9 +16,6 @@ if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
endif()
-if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
-endif()
# Source root directory for executorch.
if(NOT EXECUTORCH_ROOT)
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
@@ -33,6 +30,10 @@ set(_common_compile_options -Wno-deprecated-declarations)
include(${EXECUTORCH_ROOT}/build/Utils.cmake)
include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
+if(NOT PYTHON_EXECUTABLE)
+ resolve_python_executable()
+endif()
+
# Generate C++ bindings to register kernels into both PyTorch (for AOT) and
# Executorch (for runtime). Here select all ops in optimized.yaml
set(_yaml "${CMAKE_CURRENT_LIST_DIR}/optimized-oss.yaml")
diff --git a/kernels/portable/CMakeLists.txt b/kernels/portable/CMakeLists.txt
index c5bc9637b0..ddb9f01a40 100644
--- a/kernels/portable/CMakeLists.txt
+++ b/kernels/portable/CMakeLists.txt
@@ -16,9 +16,6 @@ if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
endif()
-if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
-endif()
# Source root directory for executorch.
if(NOT EXECUTORCH_ROOT)
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
@@ -32,6 +29,11 @@ set(_common_compile_options -Wno-deprecated-declarations)
include(${EXECUTORCH_ROOT}/build/Utils.cmake)
include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
+
+if(NOT PYTHON_EXECUTABLE)
+ resolve_python_executable()
+endif()
+
# Portable kernel sources TODO(larryliu0820): use buck2 to gather the sources
file(GLOB_RECURSE _portable_kernels__srcs
"${CMAKE_CURRENT_SOURCE_DIR}/cpu/*.cpp")
diff --git a/kernels/quantized/CMakeLists.txt b/kernels/quantized/CMakeLists.txt
index 06242246b4..7be9e73827 100644
--- a/kernels/quantized/CMakeLists.txt
+++ b/kernels/quantized/CMakeLists.txt
@@ -15,9 +15,6 @@ if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
endif()
-if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
-endif()
# Source root directory for executorch.
if(NOT EXECUTORCH_ROOT)
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
@@ -31,6 +28,11 @@ set(_common_compile_options -Wno-deprecated-declarations)
include(${EXECUTORCH_ROOT}/build/Utils.cmake)
include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
+
+if(NOT PYTHON_EXECUTABLE)
+ resolve_python_executable()
+endif()
+
# Quantized ops kernel sources TODO(larryliu0820): use buck2 to gather the
# sources
list(TRANSFORM _quantized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/")
diff --git a/scripts/release/apply-release-changes.sh b/scripts/release/apply-release-changes.sh
index 20be02feeb..fde86c97d8 100755
--- a/scripts/release/apply-release-changes.sh
+++ b/scripts/release/apply-release-changes.sh
@@ -26,7 +26,7 @@ RELEASE_BRANCH="release/${RELEASE_VERSION}"
if git ls-remote --exit-code origin ${RELEASE_BRANCH} >/dev/null 2>&1; then
echo "Check out to Release Branch '${RELEASE_BRANCH}'"
- git checkout -b ${RELEASE_BRANCH}
+ git checkout ${RELEASE_BRANCH}
else
echo "Error: Remote branch '${RELEASE_BRANCH}' not found. Please run 'cut-release-branch.sh' first."
exit 1
diff --git a/sdk/CMakeLists.txt b/sdk/CMakeLists.txt
index 7f9e5f5a04..8738c347f6 100644
--- a/sdk/CMakeLists.txt
+++ b/sdk/CMakeLists.txt
@@ -17,10 +17,6 @@ if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
endif()
-if(NOT PYTHON_EXECUTABLE)
- set(PYTHON_EXECUTABLE python3)
-endif()
-
if(NOT FLATCC_EXECUTABLE)
set(FLATCC_EXECUTABLE flatcc)
endif()
@@ -30,6 +26,12 @@ if(NOT EXECUTORCH_ROOT)
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..)
endif()
+include(${EXECUTORCH_ROOT}/build/Utils.cmake)
+
+if(NOT PYTHON_EXECUTABLE)
+ resolve_python_executable()
+endif()
+
if(NOT FLATC_EXECUTABLE)
set(FLATC_EXECUTABLE flatc)
endif()
diff --git a/setup.py b/setup.py
index 6616a09f13..3606a317e0 100644
--- a/setup.py
+++ b/setup.py
@@ -231,11 +231,6 @@ def run(self):
}
ext_modules = []
-if os.environ.get("EXECUTORCH_BUILD_AOT_UTIL", "ON") == "ON":
- ext_modules.append(
- CMakeExtension("executorch.extension.aot_util.aot_util", "extension/aot_util")
- )
-
if os.environ.get("EXECUTORCH_BUILD_PYBIND", "OFF") == "ON":
ext_modules.append(CMakeExtension("executorch.extension.pybindings.portable_lib"))