Skip to content

Commit

Permalink
Merge branch 'pytorch:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
juliagmt-google authored Jun 19, 2024
2 parents 83560f5 + df85f34 commit febf87a
Show file tree
Hide file tree
Showing 714 changed files with 20,559 additions and 14,543 deletions.
4 changes: 3 additions & 1 deletion .ci/docker/common/install_onnx.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ pip_install coloredlogs packaging
pip_install onnxruntime==1.18
pip_install onnx==1.16.0
# pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@3e869ef8ccf19b5ebd21c10d3e9c267c9a9fa729" --no-deps
pip_install onnxscript==0.1.0.dev20240523 --no-deps
pip_install onnxscript==0.1.0.dev20240613 --no-deps
# required by onnxscript
pip_install ml_dtypes

# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
Expand Down
24 changes: 1 addition & 23 deletions .ci/pytorch/common_utils.sh
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ function install_torchrec_and_fbgemm() {

function clone_pytorch_xla() {
if [[ ! -d ./xla ]]; then
git clone --recursive -b r2.4 https://github.com/pytorch/xla.git
git clone --recursive --quiet https://github.com/pytorch/xla.git
pushd xla
# pin the xla hash so that we don't get broken by changes to xla
git checkout "$(cat ../.github/ci_commit_pins/xla.txt)"
Expand All @@ -188,28 +188,6 @@ function clone_pytorch_xla() {
fi
}

function checkout_install_torchdeploy() {
local commit
commit=$(get_pinned_commit multipy)
pushd ..
git clone --recurse-submodules https://github.com/pytorch/multipy.git
pushd multipy
git checkout "${commit}"
python multipy/runtime/example/generate_examples.py
BUILD_CUDA_TESTS=1 pip install -e .
popd
popd
}

function test_torch_deploy(){
pushd ..
pushd multipy
./multipy/runtime/build/test_deploy
./multipy/runtime/build/test_deploy_gpu
popd
popd
}

function checkout_install_torchbench() {
local commit
commit=$(get_pinned_commit torchbench)
Expand Down
3 changes: 0 additions & 3 deletions .ci/pytorch/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -1242,9 +1242,6 @@ elif [[ "$TEST_CONFIG" == distributed ]]; then
if [[ "${SHARD_NUMBER}" == 1 ]]; then
test_rpc
fi
elif [[ "$TEST_CONFIG" == deploy ]]; then
checkout_install_torchdeploy
test_torch_deploy
elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
test_inductor_distributed
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
Expand Down
4 changes: 2 additions & 2 deletions .circleci/scripts/binary_populate_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:
# Only linux Python < 3.13 are supported wheels for triton
TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64' and python_version < '3.13'"
TRITON_REQUIREMENT="triton==${TRITON_VERSION}; ${TRITON_CONSTRAINT}"
if [[ -n "$PYTORCH_BUILD_VERSION" ]]; then
if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then
TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton.txt)
TRITON_REQUIREMENT="pytorch-triton==${TRITON_VERSION}+${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}"
fi
Expand All @@ -89,7 +89,7 @@ fi
# Set triton via PYTORCH_EXTRA_INSTALL_REQUIREMENTS for triton rocm package
if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*rocm.* && $(uname) == "Linux" && "$DESIRED_PYTHON" != "3.12" ]]; then
TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}"
if [[ -n "$PYTORCH_BUILD_VERSION" ]]; then
if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then
TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton-rocm.txt)
TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}+${TRITON_SHORTHASH}"
fi
Expand Down
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/torchbench.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
d6015d42d9a1834bc7595c4bd6852562fb80b30b
0dab1dd97709096e8129f8a08115ee83f64f2194
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
r2.4
6f0b61e5d782913a0fc7743812f2a8e522189111
2 changes: 0 additions & 2 deletions .github/merge_rules.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,9 @@
- third_party/onnx
- caffe2/python/onnx/**
approved_by:
- BowenBao
- justinchuby
- liqunfu
- shubhambhokare1
- thiagocrepaldi
- titaiwangms
- wschin
- xadupre
Expand Down
1 change: 1 addition & 0 deletions .github/pytorch-probot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ retryable_workflows:
- windows-binary
labeler_config: labeler.yml
label_to_label_config: label_to_label.yml
mergebot: True
114 changes: 102 additions & 12 deletions .github/scripts/cherry_pick.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import json
import os
import re
from typing import Any, Optional
from typing import Any, cast, Dict, List, Optional

from urllib.error import HTTPError

from github_utils import gh_fetch_url, gh_post_pr_comment
from github_utils import gh_fetch_url, gh_post_pr_comment, gh_query_issues_by_labels

from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
from trymerge import get_pr_commit_sha, GitHubPR
Expand All @@ -19,6 +19,7 @@
"critical",
"fixnewfeature",
}
RELEASE_BRANCH_REGEX = re.compile(r"release/(?P<version>.+)")


def parse_args() -> Any:
Expand Down Expand Up @@ -58,6 +59,33 @@ def get_merge_commit_sha(repo: GitRepo, pr: GitHubPR) -> Optional[str]:
return commit_sha if pr.is_closed() else None


def get_release_version(onto_branch: str) -> Optional[str]:
"""
Return the release version if the target branch is a release branch
"""
m = re.match(RELEASE_BRANCH_REGEX, onto_branch)
return m.group("version") if m else ""


def get_tracker_issues(
org: str, project: str, onto_branch: str
) -> List[Dict[str, Any]]:
"""
Find the tracker issue from the repo. The tracker issue needs to have the title
like [VERSION] Release Tracker following the convention on PyTorch
"""
version = get_release_version(onto_branch)
if not version:
return []

tracker_issues = gh_query_issues_by_labels(org, project, labels=["release tracker"])
if not tracker_issues:
return []

# Figure out the tracker issue from the list by looking at the title
return [issue for issue in tracker_issues if version in issue.get("title", "")]


def cherry_pick(
github_actor: str,
repo: GitRepo,
Expand All @@ -77,17 +105,49 @@ def cherry_pick(
)

try:
org, project = repo.gh_owner_and_name()

cherry_pick_pr = ""
if not dry_run:
org, project = repo.gh_owner_and_name()
cherry_pick_pr = submit_pr(repo, pr, cherry_pick_branch, onto_branch)

msg = f"The cherry pick PR is at {cherry_pick_pr}"
if fixes:
msg += f" and it is linked with issue {fixes}"
elif classification in REQUIRES_ISSUE:
msg += f" and it is recommended to link a {classification} cherry pick PR with an issue"
tracker_issues_comments = []
tracker_issues = get_tracker_issues(org, project, onto_branch)
for issue in tracker_issues:
issue_number = int(str(issue.get("number", "0")))
if not issue_number:
continue

res = cast(
Dict[str, Any],
post_tracker_issue_comment(
org,
project,
issue_number,
pr.pr_num,
cherry_pick_pr,
classification,
fixes,
dry_run,
),
)

comment_url = res.get("html_url", "")
if comment_url:
tracker_issues_comments.append(comment_url)

post_comment(org, project, pr.pr_num, msg)
msg = f"The cherry pick PR is at {cherry_pick_pr}"
if fixes:
msg += f" and it is linked with issue {fixes}."
elif classification in REQUIRES_ISSUE:
msg += f" and it is recommended to link a {classification} cherry pick PR with an issue."

if tracker_issues_comments:
msg += " The following tracker issues are updated:\n"
for tracker_issues_comment in tracker_issues_comments:
msg += f"* {tracker_issues_comment}\n"

post_pr_comment(org, project, pr.pr_num, msg, dry_run)

finally:
if current_branch:
Expand Down Expand Up @@ -159,7 +219,9 @@ def submit_pr(
raise RuntimeError(msg) from error


def post_comment(org: str, project: str, pr_num: int, msg: str) -> None:
def post_pr_comment(
org: str, project: str, pr_num: int, msg: str, dry_run: bool = False
) -> List[Dict[str, Any]]:
"""
Post a comment on the PR itself to point to the cherry picking PR when success
or print the error when failure
Expand All @@ -182,7 +244,35 @@ def post_comment(org: str, project: str, pr_num: int, msg: str) -> None:
comment = "\n".join(
(f"### Cherry picking #{pr_num}", f"{msg}", "", f"{internal_debugging}")
)
gh_post_pr_comment(org, project, pr_num, comment)
return gh_post_pr_comment(org, project, pr_num, comment, dry_run)


def post_tracker_issue_comment(
org: str,
project: str,
issue_num: int,
pr_num: int,
cherry_pick_pr: str,
classification: str,
fixes: str,
dry_run: bool = False,
) -> List[Dict[str, Any]]:
"""
Post a comment on the tracker issue (if any) to record the cherry pick
"""
comment = "\n".join(
(
"Link to landed trunk PR (if applicable):",
f"* https://github.com/{org}/{project}/pull/{pr_num}",
"",
"Link to release branch PR:",
f"* {cherry_pick_pr}",
"",
"Criteria Category:",
" - ".join((classification.capitalize(), fixes.capitalize())),
)
)
return gh_post_pr_comment(org, project, issue_num, comment, dry_run)


def main() -> None:
Expand Down Expand Up @@ -214,7 +304,7 @@ def main() -> None:

except RuntimeError as error:
if not args.dry_run:
post_comment(org, project, pr_num, str(error))
post_pr_comment(org, project, pr_num, str(error))
else:
raise error

Expand Down
4 changes: 2 additions & 2 deletions .github/scripts/filter_test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def is_cuda_or_rocm_job(job_name: Optional[str]) -> bool:
}

# The link to the published list of disabled jobs
DISABLED_JOBS_URL = "https://ossci-metrics.s3.amazonaws.com/disabled-jobs.json?versionId=tIl0Qo224T_NDVw0dtG4hU1cZJM97inV"
DISABLED_JOBS_URL = "https://ossci-metrics.s3.amazonaws.com/disabled-jobs.json"
# and unstable jobs
UNSTABLE_JOBS_URL = "https://ossci-metrics.s3.amazonaws.com/unstable-jobs.json?versionId=GPyRZRsOo26Gfk_WjAoNNxEMGXkIxIes"
UNSTABLE_JOBS_URL = "https://ossci-metrics.s3.amazonaws.com/unstable-jobs.json"

# Some constants used to handle disabled and unstable jobs
JOB_NAME_SEP = "/"
Expand Down
47 changes: 27 additions & 20 deletions .github/scripts/get_workflow_type.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from argparse import ArgumentParser
from typing import Any
from typing import Any, Tuple

from github import Auth, Github
from github.Issue import Issue
Expand All @@ -9,6 +9,8 @@
WORKFLOW_LABEL_META = "" # use meta runners
WORKFLOW_LABEL_LF = "lf." # use runners from the linux foundation
LABEL_TYPE_KEY = "label_type"
MESSAGE_KEY = "message"
MESSAGE = "" # Debug message to return to the caller


def parse_args() -> Any:
Expand Down Expand Up @@ -48,45 +50,50 @@ def is_exception_branch(branch: str) -> bool:
return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"}


def get_workflow_type(issue: Issue, username: str) -> str:
def get_workflow_type(issue: Issue, username: str) -> Tuple[str, str]:
try:
user_list = issue.get_comments()[0].body.split()

if user_list[0] == "!":
print("LF Workflows are disabled for everyone. Using meta runners.")
return WORKFLOW_LABEL_META
MESSAGE = "LF Workflows are disabled for everyone. Using meta runners."
return WORKFLOW_LABEL_META, MESSAGE
elif user_list[0] == "*":
print("LF Workflows are enabled for everyone. Using LF runners.")
return WORKFLOW_LABEL_LF
MESSAGE = "LF Workflows are enabled for everyone. Using LF runners."
return WORKFLOW_LABEL_LF, MESSAGE
elif username in user_list:
print(f"LF Workflows are enabled for {username}. Using LF runners.")
return WORKFLOW_LABEL_LF
MESSAGE = f"LF Workflows are enabled for {username}. Using LF runners."
return WORKFLOW_LABEL_LF, MESSAGE
else:
print(f"LF Workflows are disabled for {username}. Using meta runners.")
return WORKFLOW_LABEL_META
MESSAGE = f"LF Workflows are disabled for {username}. Using meta runners."
return WORKFLOW_LABEL_META, MESSAGE
except Exception as e:
print(
f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}"
)
return WORKFLOW_LABEL_META
MESSAGE = f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}"
return WORKFLOW_LABEL_META, MESSAGE


def main() -> None:
args = parse_args()

if is_exception_branch(args.github_branch):
print(f"Exception branch: '{args.github_branch}', using meta runners")
output = {LABEL_TYPE_KEY: WORKFLOW_LABEL_META}
output = {
LABEL_TYPE_KEY: WORKFLOW_LABEL_META,
MESSAGE_KEY: f"Exception branch: '{args.github_branch}', using meta runners",
}
else:
try:
gh = get_gh_client(args.github_token)
# The default issue we use - https://github.com/pytorch/test-infra/issues/5132
issue = get_issue(gh, args.github_repo, args.github_issue)

output = {LABEL_TYPE_KEY: get_workflow_type(issue, args.github_user)}
label_type, message = get_workflow_type(issue, args.github_user)
output = {
LABEL_TYPE_KEY: label_type,
MESSAGE_KEY: message,
}
except Exception as e:
print(f"Failed to get issue. Falling back to meta runners. Exception: {e}")
output = {LABEL_TYPE_KEY: WORKFLOW_LABEL_META}
output = {
LABEL_TYPE_KEY: WORKFLOW_LABEL_META,
MESSAGE_KEY: f"Failed to get issue. Falling back to meta runners. Exception: {e}",
}

json_output = json.dumps(output)
print(json_output)
Expand Down
9 changes: 9 additions & 0 deletions .github/scripts/github_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,12 @@ def gh_update_pr_state(org: str, repo: str, pr_num: int, state: str = "open") ->
)
else:
raise


def gh_query_issues_by_labels(
org: str, repo: str, labels: List[str], state: str = "open"
) -> List[Dict[str, Any]]:
url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues"
return gh_fetch_json(
url, method="GET", params={"labels": ",".join(labels), "state": state}
)
3 changes: 3 additions & 0 deletions .github/scripts/test_trymerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ def mock_gh_get_info() -> Any:
return {
"closed": False,
"isCrossRepository": False,
"headRefName": "foo",
"baseRefName": "bar",
"baseRepository": {"defaultBranchRef": {"name": "bar"}},
"files": {"nodes": [], "pageInfo": {"hasNextPage": False}},
"changedFiles": 0,
}
Expand Down
Loading

0 comments on commit febf87a

Please sign in to comment.